# PydanticAI > Agent Framework / shim to use Pydantic with LLMs PydanticAI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI. # Concepts documentation ## Introduction Agents are PydanticAI's primary interface for interacting with LLMs. In some use cases a single Agent will control an entire application or component, but multiple agents can also interact to embody more complex workflows. The Agent class has full API documentation, but conceptually you can think of an agent as a container for: | **Component** | **Description** | | --- | --- | | [System prompt(s)](#system-prompts) | A set of instructions for the LLM written by the developer. | | [Function tool(s)](../tools/) | Functions that the LLM may call to get information while generating a response. | | [Structured output type](../output/) | The structured datatype the LLM must return at the end of a run, if specified. | | [Dependency type constraint](../dependencies/) | System prompt functions, tools, and output validators may all use dependencies when they're run. | | [LLM model](../api/models/base/) | Optional default LLM model associated with the agent. Can also be specified when running the agent. | | [Model Settings](#additional-configuration) | Optional default model settings to help fine tune requests. Can also be specified when running the agent. | In typing terms, agents are generic in their dependency and output types, e.g., an agent which required dependencies of type `Foobar` and produced outputs of type `list[str]` would have type `Agent[Foobar, list[str]]`. In practice, you shouldn't need to care about this, it should just mean your IDE can tell you when you have the right type, and if you choose to use [static type checking](#static-type-checking) it should work well with PydanticAI. Here's a toy example of an agent that simulates a roulette wheel: roulette_wheel.py ```python from pydantic_ai import Agent, RunContext roulette_agent = Agent( # (1)! 'openai:gpt-4o', deps_type=int, output_type=bool, system_prompt=( 'Use the `roulette_wheel` function to see if the ' 'customer has won based on the number they provide.' ), ) @roulette_agent.tool async def roulette_wheel(ctx: RunContext[int], square: int) -> str: # (2)! """check if the square is a winner""" return 'winner' if square == ctx.deps else 'loser' # Run the agent success_number = 18 # (3)! result = roulette_agent.run_sync('Put my money on square eighteen', deps=success_number) print(result.output) # (4)! #> True result = roulette_agent.run_sync('I bet five is the winner', deps=success_number) print(result.output) #> False ``` 1. Create an agent, which expects an integer dependency and produces a boolean output. This agent will have type `Agent[int, bool]`. 1. Define a tool that checks if the square is a winner. Here RunContext is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error. 1. In reality, you might want to use a random number here e.g. `random.randint(0, 36)`. 1. `result.output` will be a boolean indicating if the square is a winner. Pydantic performs the output validation, and it'll be typed as a `bool` since its type is derived from the `output_type` generic parameter of the agent. Agents are designed for reuse, like FastAPI Apps Agents are intended to be instantiated once (frequently as module globals) and reused throughout your application, similar to a small FastAPI app or an APIRouter. 
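As a rough sketch of that reuse pattern (the agent name and the `summarize` helper here are illustrative, not part of PydanticAI), an agent defined once as a module global can be imported and reused by any code that needs it:

```python
from pydantic_ai import Agent

# Instantiated once at import time and reused for every run,
# like a small FastAPI app or an APIRouter.
summary_agent = Agent(
    'openai:gpt-4o',
    system_prompt='Summarize the text you are given in one sentence.',
)


async def summarize(text: str) -> str:
    # Reuse the same agent instance for every call.
    result = await summary_agent.run(text)
    return result.output
```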
## Running Agents There are four ways to run an agent: 1. agent.run() — a coroutine which returns a RunResult containing a completed response. 1. agent.run_sync() — a plain, synchronous function which returns a RunResult containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). 1. agent.run_stream() — a coroutine which returns a StreamedRunResult, which contains methods to stream a response as an async iterable. 1. agent.iter() — a context manager which returns an AgentRun, an async-iterable over the nodes of the agent's underlying Graph. Here's a simple example demonstrating the first three: run_agent.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result_sync = agent.run_sync('What is the capital of Italy?') print(result_sync.output) #> Rome async def main(): result = await agent.run('What is the capital of France?') print(result.output) #> Paris async with agent.run_stream('What is the capital of the UK?') as response: print(await response.get_output()) #> London ``` *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* You can also pass messages from previous runs to continue a conversation or provide context, as described in [Messages and Chat History](../message-history/). ### Iterating Over an Agent's Graph Under the hood, each `Agent` in PydanticAI uses **pydantic-graph** to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on PydanticAI — you can use it standalone for workflows that have nothing to do with GenAI — but PydanticAI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. In many scenarios, you don't need to worry about pydantic-graph at all; calling `agent.run(...)` simply traverses the underlying graph from start to finish. However, if you need deeper insight or control — for example to capture each tool invocation, or to inject your own logic at specific stages — PydanticAI exposes the lower-level iteration process via Agent.iter. This method returns an AgentRun, which you can async-iterate over, or manually drive node-by-node via the next method. Once the agent's graph returns an End, you have the final result along with a detailed history of all steps. 
#### `async for` iteration Here's an example of using `async for` with `iter` to record each node the agent executes: agent_iter_async_for.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): nodes = [] # Begin an AgentRun, which is an async-iterable over the nodes of the agent's graph async with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: # Each node represents a step in the agent's execution nodes.append(node) print(nodes) """ [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] """ print(agent_run.result.output) #> Paris ``` - The `AgentRun` is an async iterator that yields each node (`BaseNode` or `End`) in the flow. - The run ends when an `End` node is returned. #### Using `.next(...)` manually You can also drive the iteration manually by passing the node you want to run next to the `AgentRun.next(...)` method. This allows you to inspect or modify the node before it executes or skip nodes based on your own logic, and to catch errors in `next()` more easily: agent_iter_next.py ```python from pydantic_ai import Agent from pydantic_graph import End agent = Agent('openai:gpt-4o') async def main(): async with agent.iter('What is the capital of France?') as agent_run: node = agent_run.next_node # (1)! all_nodes = [node] # Drive the iteration manually: while not isinstance(node, End): # (2)! node = await agent_run.next(node) # (3)! all_nodes.append(node) # (4)! print(all_nodes) """ [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57, ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] """ ``` 1. We start by grabbing the first node that will be run in the agent's graph. 1. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. 1. When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the *next* node to run. 1. You could also inspect or mutate the new `node` here as needed. #### Accessing usage and the final output You can retrieve usage statistics (tokens, requests, etc.) at any time from the AgentRun object via `agent_run.usage()`. This method returns a Usage object containing the usage data. Once the run finishes, `agent_run.result` becomes a AgentRunResult object containing the final output (and related metadata). 
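As a minimal sketch of the above (reusing the same agent and prompt as the earlier examples), you can check usage while iterating and read the final output once the run has ended:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            pass  # drive the graph to completion, node by node
        print(agent_run.usage())  # request and token counts accumulated during the run
        print(agent_run.result.output)
        #> Paris
```

*(As with the earlier examples, you'll need to add `asyncio.run(main())` to run `main`.)*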
______________________________________________________________________ ### Streaming Here is an example of streaming an agent run in combination with `async for` iteration: streaming.py ```python import asyncio from dataclasses import dataclass from datetime import date from pydantic_ai import Agent from pydantic_ai.messages import ( FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, PartDeltaEvent, PartStartEvent, TextPartDelta, ToolCallPartDelta, ) from pydantic_ai.tools import RunContext @dataclass class WeatherService: async def get_forecast(self, location: str, forecast_date: date) -> str: # In real code: call weather API, DB queries, etc. return f'The forecast in {location} on {forecast_date} is 24°C and sunny.' async def get_historic_weather(self, location: str, forecast_date: date) -> str: # In real code: call a historical weather API or DB return ( f'The weather in {location} on {forecast_date} was 18°C and partly cloudy.' ) weather_agent = Agent[WeatherService, str]( 'openai:gpt-4o', deps_type=WeatherService, output_type=str, # We'll produce a final answer as plain text system_prompt='Providing a weather forecast at the locations the user provides.', ) @weather_agent.tool async def weather_forecast( ctx: RunContext[WeatherService], location: str, forecast_date: date, ) -> str: if forecast_date >= date.today(): return await ctx.deps.get_forecast(location, forecast_date) else: return await ctx.deps.get_historic_weather(location, forecast_date) output_messages: list[str] = [] async def main(): user_prompt = 'What will the weather be like in Paris on Tuesday?' # Begin a node-by-node, streaming iteration async with weather_agent.iter(user_prompt, deps=WeatherService()) as run: async for node in run: if Agent.is_user_prompt_node(node): # A user prompt node => The user has provided input output_messages.append(f'=== UserPromptNode: {node.user_prompt} ===') elif Agent.is_model_request_node(node): # A model request node => We can stream tokens from the model's request output_messages.append( '=== ModelRequestNode: streaming partial request tokens ===' ) async with node.stream(run.ctx) as request_stream: async for event in request_stream: if isinstance(event, PartStartEvent): output_messages.append( f'[Request] Starting part {event.index}: {event.part!r}' ) elif isinstance(event, PartDeltaEvent): if isinstance(event.delta, TextPartDelta): output_messages.append( f'[Request] Part {event.index} text delta: {event.delta.content_delta!r}' ) elif isinstance(event.delta, ToolCallPartDelta): output_messages.append( f'[Request] Part {event.index} args_delta={event.delta.args_delta}' ) elif isinstance(event, FinalResultEvent): output_messages.append( f'[Result] The model produced a final output (tool_name={event.tool_name})' ) elif Agent.is_call_tools_node(node): # A handle-response node => The model returned some data, potentially calls a tool output_messages.append( '=== CallToolsNode: streaming partial response & tool usage ===' ) async with node.stream(run.ctx) as handle_stream: async for event in handle_stream: if isinstance(event, FunctionToolCallEvent): output_messages.append( f'[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})' ) elif isinstance(event, FunctionToolResultEvent): output_messages.append( f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}' ) elif Agent.is_end_node(node): assert run.result.output == node.data.output # Once an End node is reached, the agent run is 
complete output_messages.append( f'=== Final Agent Output: {run.result.output} ===' ) if __name__ == '__main__': asyncio.run(main()) print(output_messages) """ [ '=== UserPromptNode: What will the weather be like in Paris on Tuesday? ===', '=== ModelRequestNode: streaming partial request tokens ===', "[Request] Starting part 0: ToolCallPart(tool_name='weather_forecast', tool_call_id='0001')", '[Request] Part 0 args_delta={"location":"Pa', '[Request] Part 0 args_delta=ris","forecast_', '[Request] Part 0 args_delta=date":"2030-01-', '[Request] Part 0 args_delta=01"}', '=== CallToolsNode: streaming partial response & tool usage ===', '[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=\'0001\')', "[Tools] Tool call '0001' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.", '=== ModelRequestNode: streaming partial request tokens ===', "[Request] Starting part 0: TextPart(content='It will be ')", '[Result] The model produced a final output (tool_name=None)', "[Request] Part 0 text delta: 'warm and sunny '", "[Request] Part 0 text delta: 'in Paris on '", "[Request] Part 0 text delta: 'Tuesday.'", '=== CallToolsNode: streaming partial response & tool usage ===', '=== Final Agent Output: It will be warm and sunny in Paris on Tuesday. ===', ] """ ``` ______________________________________________________________________ ### Additional Configuration #### Usage Limits PydanticAI offers a UsageLimits structure to help you limit your usage (tokens and/or requests) on model runs. You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions. Consider the following example, where we limit the number of response tokens: ```py from pydantic_ai import Agent from pydantic_ai.exceptions import UsageLimitExceeded from pydantic_ai.usage import UsageLimits agent = Agent('anthropic:claude-3-5-sonnet-latest') result_sync = agent.run_sync( 'What is the capital of Italy? Answer with just the city.', usage_limits=UsageLimits(response_tokens_limit=10), ) print(result_sync.output) #> Rome print(result_sync.usage()) #> Usage(requests=1, request_tokens=62, response_tokens=1, total_tokens=63) try: result_sync = agent.run_sync( 'What is the capital of Italy? Answer with a paragraph.', usage_limits=UsageLimits(response_tokens_limit=10), ) except UsageLimitExceeded as e: print(e) #> Exceeded the response_tokens_limit of 10 (response_tokens=32) ``` Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling: ```py from typing_extensions import TypedDict from pydantic_ai import Agent, ModelRetry from pydantic_ai.exceptions import UsageLimitExceeded from pydantic_ai.usage import UsageLimits class NeverOutputType(TypedDict): """ Never ever coerce data to this type. """ never_use_this: str agent = Agent( 'anthropic:claude-3-5-sonnet-latest', retries=3, output_type=NeverOutputType, system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.', ) @agent.tool_plain(retries=5) # (1)! def infinite_retry_tool() -> int: raise ModelRetry('Please try again.') try: result_sync = agent.run_sync( 'Begin infinite retry loop!', usage_limits=UsageLimits(request_limit=3) # (2)! ) except UsageLimitExceeded as e: print(e) #> The next request would exceed the request_limit of 3 ``` 1. This tool has the ability to retry 5 times before erroring, simulating a tool that might get stuck in a loop. 1. 
This run will error after 3 requests, preventing the infinite tool calling. Note This is especially relevant if you've registered many tools. The `request_limit` can be used to prevent the model from calling them in a loop too many times. #### Model (Run) Settings PydanticAI offers a settings.ModelSettings structure to help you fine tune your requests. This structure allows you to configure common parameters that influence the model's behavior, such as `temperature`, `max_tokens`, `timeout`, and more. There are two ways to apply these settings: 1. Passing to `run{_sync,_stream}` functions via the `model_settings` argument. This allows for fine-tuning on a per-request basis. 1. Setting during Agent initialization via the `model_settings` argument. These settings will be applied by default to all subsequent run calls using said agent. However, `model_settings` provided during a specific run call will override the agent's default settings. For example, if you'd like to set the `temperature` setting to `0.0` to ensure less random behavior, you can do the following: ```py from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result_sync = agent.run_sync( 'What is the capital of Italy?', model_settings={'temperature': 0.0} ) print(result_sync.output) #> Rome ``` ### Model specific settings If you wish to further customize model behavior, you can use a subclass of ModelSettings, like GeminiModelSettings, associated with your model of choice. For example: ```py from pydantic_ai import Agent, UnexpectedModelBehavior from pydantic_ai.models.gemini import GeminiModelSettings agent = Agent('google-gla:gemini-1.5-flash') try: result = agent.run_sync( 'Write a list of 5 very rude things that I might say to the universe after stubbing my toe in the dark:', model_settings=GeminiModelSettings( temperature=0.0, # general model settings can also be specified gemini_safety_settings=[ { 'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_LOW_AND_ABOVE', }, { 'category': 'HARM_CATEGORY_HATE_SPEECH', 'threshold': 'BLOCK_LOW_AND_ABOVE', }, ], ), ) except UnexpectedModelBehavior as e: print(e) # (1)! """ Safety settings triggered, body: """ ``` 1. This error is raised because the safety thresholds were exceeded. ## Runs vs. Conversations An agent **run** might represent an entire conversation — there's no limit to how many messages can be exchanged in a single run. However, a **conversation** might also be composed of multiple runs, especially if you need to maintain state between separate interactions or API calls. Here's an example of a conversation comprised of multiple runs: conversation_example.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') # First run result1 = agent.run_sync('Who was Albert Einstein?') print(result1.output) #> Albert Einstein was a German-born theoretical physicist. # Second run, passing previous messages result2 = agent.run_sync( 'What was his most famous equation?', message_history=result1.new_messages(), # (1)! ) print(result2.output) #> Albert Einstein's most famous equation is (E = mc^2). ``` 1. Continue the conversation; without `message_history` the model would not know who "his" was referring to. *(This example is complete, it can be run "as is")* ## Type safe by design PydanticAI is designed to work well with static type checkers, like mypy and pyright. Typing is (somewhat) optional PydanticAI is designed to make type checking as useful as possible for you if you choose to use it, but you don't have to use types everywhere all the time. 
That said, because PydanticAI uses Pydantic, and Pydantic uses type hints as the definition for schema and validation, some types (specifically type hints on parameters to tools, and the `output_type` arguments to Agent) are used at runtime. We (the library developers) have messed up if type hints are confusing you more than helping you, if you find this, please create an [issue](https://github.com/pydantic/pydantic-ai/issues) explaining what's annoying you! In particular, agents are generic in both the type of their dependencies and the type of the outputs they return, so you can use the type hints to ensure you're using the right types. Consider the following script with type mistakes: type_mistakes.py ```python from dataclasses import dataclass from pydantic_ai import Agent, RunContext @dataclass class User: name: str agent = Agent( 'test', deps_type=User, # (1)! output_type=bool, ) @agent.system_prompt def add_user_name(ctx: RunContext[str]) -> str: # (2)! return f"The user's name is {ctx.deps}." def foobar(x: bytes) -> None: pass result = agent.run_sync('Does their name start with "A"?', deps=User('Anne')) foobar(result.output) # (3)! ``` 1. The agent is defined as expecting an instance of `User` as `deps`. 1. But here `add_user_name` is defined as taking a `str` as the dependency, not a `User`. 1. Since the agent is defined as returning a `bool`, this will raise a type error since `foobar` expects `bytes`. Running `mypy` on this will give the following output: ```bash ➤ uv run mypy type_mistakes.py type_mistakes.py:18: error: Argument 1 to "system_prompt" of "Agent" has incompatible type "Callable[[RunContext[str]], str]"; expected "Callable[[RunContext[User]], str]" [arg-type] type_mistakes.py:28: error: Argument 1 to "foobar" has incompatible type "bool"; expected "bytes" [arg-type] Found 2 errors in 1 file (checked 1 source file) ``` Running `pyright` would identify the same issues. ## System Prompts System prompts might seem simple at first glance since they're just strings (or sequences of strings that are concatenated), but crafting the right system prompt is key to getting the model to behave as you want. Tip For most use cases, you should use `instructions` instead of "system prompts". If you know what you are doing though and want to preserve system prompt messages in the message history sent to the LLM in subsequent completions requests, you can achieve this using the `system_prompt` argument/decorator. See the section below on [Instructions](#instructions) for more information. Generally, system prompts fall into two categories: 1. **Static system prompts**: These are known when writing the code and can be defined via the `system_prompt` parameter of the Agent constructor. 1. **Dynamic system prompts**: These depend in some way on context that isn't known until runtime, and should be defined via functions decorated with @agent.system_prompt. You can add both to a single agent; they're appended in the order they're defined at runtime. Here's an example using both types of system prompts: system_prompts.py ```python from datetime import date from pydantic_ai import Agent, RunContext agent = Agent( 'openai:gpt-4o', deps_type=str, # (1)! system_prompt="Use the customer's name while replying to them.", # (2)! ) @agent.system_prompt # (3)! def add_the_users_name(ctx: RunContext[str]) -> str: return f"The user's name is {ctx.deps}." @agent.system_prompt def add_the_date() -> str: # (4)! return f'The date is {date.today()}.' 
result = agent.run_sync('What is the date?', deps='Frank') print(result.output) #> Hello Frank, the date today is 2032-01-02. ``` 1. The agent expects a string dependency. 1. Static system prompt defined at agent creation time. 1. Dynamic system prompt defined via a decorator with RunContext, this is called just after `run_sync`, not when the agent is created, so can benefit from runtime information like the dependencies used on that run. 1. Another dynamic system prompt, system prompts don't have to have the `RunContext` parameter. *(This example is complete, it can be run "as is")* ## Instructions Instructions are similar to system prompts. The main difference is that when an explicit `message_history` is provided in a call to `Agent.run` and similar methods, *instructions* from any existing messages in the history are not included in the request to the model — only the instructions of the *current* agent are included. You should use: - `instructions` when you want your request to the model to only include system prompts for the *current* agent - `system_prompt` when you want your request to the model to *retain* the system prompts used in previous requests (possibly made using other agents) In general, we recommend using `instructions` instead of `system_prompt` unless you have a specific reason to use `system_prompt`. Instructions, like system prompts, fall into two categories: 1. **Static instructions**: These are known when writing the code and can be defined via the `instructions` parameter of the Agent constructor. 1. **Dynamic instructions**: These rely on context that is only available at runtime and should be defined using functions decorated with @agent.instructions. Unlike dynamic system prompts, which may be reused when `message_history` is present, dynamic instructions are always reevaluated. Both static and dynamic instructions can be added to a single agent, and they are appended in the order they are defined at runtime. Here's an example using both types of instructions: instructions.py ```python from datetime import date from pydantic_ai import Agent, RunContext agent = Agent( 'openai:gpt-4o', deps_type=str, # (1)! instructions="Use the customer's name while replying to them.", # (2)! ) @agent.instructions # (3)! def add_the_users_name(ctx: RunContext[str]) -> str: return f"The user's name is {ctx.deps}." @agent.instructions def add_the_date() -> str: # (4)! return f'The date is {date.today()}.' result = agent.run_sync('What is the date?', deps='Frank') print(result.output) #> Hello Frank, the date today is 2032-01-02. ``` 1. The agent expects a string dependency. 1. Static instructions defined at agent creation time. 1. Dynamic instructions defined via a decorator with RunContext, this is called just after `run_sync`, not when the agent is created, so can benefit from runtime information like the dependencies used on that run. 1. Another dynamic instruction, instructions don't have to have the `RunContext` parameter. *(This example is complete, it can be run "as is")* Note that returning an empty string will result in no instruction message added. ## Reflection and self-correction Validation errors from both function tool parameter validation and [structured output validation](../output/#structured-output) can be passed back to the model with a request to retry. You can also raise ModelRetry from within a [tool](../tools/) or [output validator function](../output/#output-validator-functions) to tell the model it should retry generating a response. 
- The default retry count is **1** but can be altered for the entire agent, a specific tool, or an output validator. - You can access the current retry count from within a tool or output validator via ctx.retry. Here's an example: tool_retry.py ```python from pydantic import BaseModel from pydantic_ai import Agent, RunContext, ModelRetry from fake_database import DatabaseConn class ChatResult(BaseModel): user_id: int message: str agent = Agent( 'openai:gpt-4o', deps_type=DatabaseConn, output_type=ChatResult, ) @agent.tool(retries=2) def get_user_by_name(ctx: RunContext[DatabaseConn], name: str) -> int: """Get a user's ID from their full name.""" print(name) #> John #> John Doe user_id = ctx.deps.users.get(name=name) if user_id is None: raise ModelRetry( f'No user found with name {name!r}, remember to provide their full name' ) return user_id result = agent.run_sync( 'Send a message to John Doe asking for coffee next week', deps=DatabaseConn() ) print(result.output) """ user_id=123 message='Hello John, would you be free for coffee sometime next week? Let me know what works for you!' """ ``` ## Model errors If models behave unexpectedly (e.g., the retry limit is exceeded, or their API returns `503`), agent runs will raise UnexpectedModelBehavior. In these cases, capture_run_messages can be used to access the messages exchanged during the run to help diagnose the issue. agent_model_errors.py ```python from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, capture_run_messages agent = Agent('openai:gpt-4o') @agent.tool_plain def calc_volume(size: int) -> int: # (1)! if size == 42: return size**3 else: raise ModelRetry('Please try again.') with capture_run_messages() as messages: # (2)! try: result = agent.run_sync('Please get me the volume of a box with size 6.') except UnexpectedModelBehavior as e: print('An error occurred:', e) #> An error occurred: Tool exceeded max retries count of 1 print('cause:', repr(e.__cause__)) #> cause: ModelRetry('Please try again.') print('messages:', messages) """ messages: [ ModelRequest( parts=[ UserPromptPart( content='Please get me the volume of a box with size 6.', timestamp=datetime.datetime(...), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id', ) ], usage=Usage( requests=1, request_tokens=62, response_tokens=4, total_tokens=66 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ModelRequest( parts=[ RetryPromptPart( content='Please try again.', tool_name='calc_volume', tool_call_id='pyd_ai_tool_call_id', timestamp=datetime.datetime(...), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id', ) ], usage=Usage( requests=1, request_tokens=72, response_tokens=8, total_tokens=80 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ] """ else: print(result.output) ``` 1. Define a tool that will raise `ModelRetry` repeatedly in this case. 1. capture_run_messages is used to capture the messages exchanged during the run. *(This example is complete, it can be run "as is")* Note If you call run, run_sync, or run_stream more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. # Common Tools PydanticAI ships with native tools that can be used to enhance your agent's capabilities. ## DuckDuckGo Search Tool The DuckDuckGo search tool allows you to search the web for information. 
It is built on top of the [DuckDuckGo API](https://github.com/deedy5/duckduckgo_search). ### Installation To use duckduckgo_search_tool, you need to install [`pydantic-ai-slim`](../install/#slim-install) with the `duckduckgo` optional group: ```bash pip install "pydantic-ai-slim[duckduckgo]" ``` ```bash uv add "pydantic-ai-slim[duckduckgo]" ``` ### Usage Here's an example of how you can use the DuckDuckGo search tool with an agent: duckduckgo_search.py ```py from pydantic_ai import Agent from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool agent = Agent( 'openai:o3-mini', tools=[duckduckgo_search_tool()], system_prompt='Search DuckDuckGo for the given query and return the results.', ) result = agent.run_sync( 'Can you list the top five highest-grossing animated films of 2025?' ) print(result.output) """ I looked into several sources on animated box‐office performance in 2025, and while detailed rankings can shift as more money is tallied, multiple independent reports have already highlighted a couple of record‐breaking shows. For example: • Ne Zha 2 – News outlets (Variety, Wikipedia's "List of animated feature films of 2025", and others) have reported that this Chinese title not only became the highest‑grossing animated film of 2025 but also broke records as the highest‑grossing non‑English animated film ever. One article noted its run exceeded US$1.7 billion. • Inside Out 2 – According to data shared on Statista and in industry news, this Pixar sequel has been on pace to set new records (with some sources even noting it as the highest‑grossing animated film ever, as of January 2025). Beyond those two, some entertainment trade sites (for example, a Just Jared article titled "Top 10 Highest-Earning Animated Films at the Box Office Revealed") have begun listing a broader top‑10. Although full consolidated figures can sometimes differ by source and are updated daily during a box‑office run, many of the industry trackers have begun to single out five films as the biggest earners so far in 2025. Unfortunately, although multiple articles discuss the "top animated films" of 2025, there isn't yet a single, universally accepted list with final numbers that names the complete top five. (Box‑office rankings, especially mid‑year, can be fluid as films continue to add to their totals.) Based on what several sources note so far, the two undisputed leaders are: 1. Ne Zha 2 2. Inside Out 2 The remaining top spots (3–5) are reported by some outlets in their "Top‑10 Animated Films" lists for 2025 but the titles and order can vary depending on the source and the exact cut‑off date of the data. For the most up‑to‑date and detailed ranking (including the 3rd, 4th, and 5th highest‑grossing films), I recommend checking resources like: • Wikipedia's "List of animated feature films of 2025" page • Box‑office tracking sites (such as Box Office Mojo or The Numbers) • Trade articles like the one on Just Jared To summarize with what is clear from the current reporting: 1. Ne Zha 2 2. Inside Out 2 3–5. Other animated films (yet to be definitively finalized across all reporting outlets) If you're looking for a final, consensus list of the top five, it may be best to wait until the 2025 year‑end box‑office tallies are in or to consult a regularly updated entertainment industry source. Would you like help finding a current source or additional details on where to look for the complete updated list? 
""" ``` ## Tavily Search Tool Info Tavily is a paid service, but they have free credits to explore their product. You need to [sign up for an account](https://app.tavily.com/home) and get an API key to use the Tavily search tool. The Tavily search tool allows you to search the web for information. It is built on top of the [Tavily API](https://tavily.com/). ### Installation To use tavily_search_tool, you need to install [`pydantic-ai-slim`](../install/#slim-install) with the `tavily` optional group: ```bash pip install "pydantic-ai-slim[tavily]" ``` ```bash uv add "pydantic-ai-slim[tavily]" ``` ### Usage Here's an example of how you can use the Tavily search tool with an agent: tavily_search.py ```py import os from pydantic_ai.agent import Agent from pydantic_ai.common_tools.tavily import tavily_search_tool api_key = os.getenv('TAVILY_API_KEY') assert api_key is not None agent = Agent( 'openai:o3-mini', tools=[tavily_search_tool(api_key)], system_prompt='Search Tavily for the given query and return the results.', ) result = agent.run_sync('Tell me the top news in the GenAI world, give me links.') print(result.output) """ Here are some of the top recent news articles related to GenAI: 1. How CLEAR users can improve risk analysis with GenAI – Thomson Reuters Read more: https://legal.thomsonreuters.com/blog/how-clear-users-can-improve-risk-analysis-with-genai/ (This article discusses how CLEAR's new GenAI-powered tool streamlines risk analysis by quickly summarizing key information from various public data sources.) 2. TELUS Digital Survey Reveals Enterprise Employees Are Entering Sensitive Data Into AI Assistants More Than You Think – FT.com Read more: https://markets.ft.com/data/announce/detail?dockey=600-202502260645BIZWIRE_USPRX____20250226_BW490609-1 (This news piece highlights findings from a TELUS Digital survey showing that many enterprise employees use public GenAI tools and sometimes even enter sensitive data.) 3. The Essential Guide to Generative AI – Virtualization Review Read more: https://virtualizationreview.com/Whitepapers/2025/02/SNOWFLAKE-The-Essential-Guide-to-Generative-AI.aspx (This guide provides insights into how GenAI is revolutionizing enterprise strategies and productivity, with input from industry leaders.) Feel free to click on the links to dive deeper into each story! """ ``` # Dependencies PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](../agents/#system-prompts), [tools](../tools/) and [output validators](../output/#output-validator-functions). Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production. ## Defining Dependencies Dependencies can be any python type. While in simple cases you might be able to pass a single object as a dependency (e.g. an HTTP connection), dataclasses are generally a convenient container when your dependencies included multiple objects. Here's an example of defining an agent that requires dependencies. (**Note:** dependencies aren't actually used in this example, see [Accessing Dependencies](#accessing-dependencies) below) unused_dependencies.py ```python from dataclasses import dataclass import httpx from pydantic_ai import Agent @dataclass class MyDeps: # (1)! 
api_key: str http_client: httpx.AsyncClient agent = Agent( 'openai:gpt-4o', deps_type=MyDeps, # (2)! ) async def main(): async with httpx.AsyncClient() as client: deps = MyDeps('foobar', client) result = await agent.run( 'Tell me a joke.', deps=deps, # (3)! ) print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. Define a dataclass to hold dependencies. 1. Pass the dataclass type to the `deps_type` argument of the Agent constructor. **Note**: we're passing the type here, NOT an instance. This parameter is not actually used at runtime; it's here so we can get full type checking of the agent. 1. When running the agent, pass an instance of the dataclass to the `deps` parameter. *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* ## Accessing Dependencies Dependencies are accessed through the RunContext type; this should be the first parameter of system prompt functions, tools, and output validators. system_prompt_dependencies.py ```python from dataclasses import dataclass import httpx from pydantic_ai import Agent, RunContext @dataclass class MyDeps: api_key: str http_client: httpx.AsyncClient agent = Agent( 'openai:gpt-4o', deps_type=MyDeps, ) @agent.system_prompt # (1)! async def get_system_prompt(ctx: RunContext[MyDeps]) -> str: # (2)! response = await ctx.deps.http_client.get( # (3)! 'https://example.com', headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, # (4)! ) response.raise_for_status() return f'Prompt: {response.text}' async def main(): async with httpx.AsyncClient() as client: deps = MyDeps('foobar', client) result = await agent.run('Tell me a joke.', deps=deps) print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. RunContext may optionally be passed to a system_prompt function as the only argument. 1. RunContext is parameterized with the type of the dependencies; if this type is incorrect, static type checkers will raise an error. 1. Access dependencies through the .deps attribute. 1. Access dependencies through the .deps attribute. *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* ### Asynchronous vs. Synchronous dependencies [System prompt functions](../agents/#system-prompts), [function tools](../tools/) and [output validators](../output/#output-validator-functions) are all run in the async context of an agent run. If these functions are not coroutines (i.e. not defined with `async def`), they are called with run_in_executor in a thread pool. It's therefore marginally preferable to use `async` methods where dependencies perform IO, although synchronous dependencies should work fine too. `run` vs. `run_sync` and Asynchronous vs. Synchronous dependencies Whether you use synchronous or asynchronous dependencies is completely independent of whether you use `run` or `run_sync` — `run_sync` is just a wrapper around `run` and agents are always run in an async context. Here's the same example as above, but with a synchronous dependency: sync_dependencies.py ```python from dataclasses import dataclass import httpx from pydantic_ai import Agent, RunContext @dataclass class MyDeps: api_key: str http_client: httpx.Client # (1)! agent = Agent( 'openai:gpt-4o', deps_type=MyDeps, ) @agent.system_prompt def get_system_prompt(ctx: RunContext[MyDeps]) -> str: # (2)!
response = ctx.deps.http_client.get( 'https://example.com', headers={'Authorization': f'Bearer {ctx.deps.api_key}'} ) response.raise_for_status() return f'Prompt: {response.text}' async def main(): deps = MyDeps('foobar', httpx.Client()) result = await agent.run( 'Tell me a joke.', deps=deps, ) print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. Here we use a synchronous `httpx.Client` instead of an asynchronous `httpx.AsyncClient`. 1. To match the synchronous dependency, the system prompt function is now a plain function, not a coroutine. *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* ## Full Example As well as system prompts, dependencies can be used in [tools](../tools/) and [output validators](../output/#output-validator-functions). full_example.py ```python from dataclasses import dataclass import httpx from pydantic_ai import Agent, ModelRetry, RunContext @dataclass class MyDeps: api_key: str http_client: httpx.AsyncClient agent = Agent( 'openai:gpt-4o', deps_type=MyDeps, ) @agent.system_prompt async def get_system_prompt(ctx: RunContext[MyDeps]) -> str: response = await ctx.deps.http_client.get('https://example.com') response.raise_for_status() return f'Prompt: {response.text}' @agent.tool # (1)! async def get_joke_material(ctx: RunContext[MyDeps], subject: str) -> str: response = await ctx.deps.http_client.get( 'https://example.com#jokes', params={'subject': subject}, headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, ) response.raise_for_status() return response.text @agent.output_validator # (2)! async def validate_output(ctx: RunContext[MyDeps], output: str) -> str: response = await ctx.deps.http_client.post( 'https://example.com#validate', headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, params={'query': output}, ) if response.status_code == 400: raise ModelRetry(f'invalid response: {response.text}') response.raise_for_status() return output async def main(): async with httpx.AsyncClient() as client: deps = MyDeps('foobar', client) result = await agent.run('Tell me a joke.', deps=deps) print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. ``` 1. To pass `RunContext` to a tool, use the tool decorator. 1. `RunContext` may optionally be passed to a output_validator function as the first argument. *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* ## Overriding Dependencies When testing agents, it's useful to be able to customise dependencies. While this can sometimes be done by calling the agent directly within unit tests, we can also override dependencies while calling application code which in turn calls the agent. This is done via the override method on the agent. joke_app.py ```python from dataclasses import dataclass import httpx from pydantic_ai import Agent, RunContext @dataclass class MyDeps: api_key: str http_client: httpx.AsyncClient async def system_prompt_factory(self) -> str: # (1)! response = await self.http_client.get('https://example.com') response.raise_for_status() return f'Prompt: {response.text}' joke_agent = Agent('openai:gpt-4o', deps_type=MyDeps) @joke_agent.system_prompt async def get_system_prompt(ctx: RunContext[MyDeps]) -> str: return await ctx.deps.system_prompt_factory() # (2)! async def application_code(prompt: str) -> str: # (3)! ... ... 
# now deep within application code we call our agent async with httpx.AsyncClient() as client: app_deps = MyDeps('foobar', client) result = await joke_agent.run(prompt, deps=app_deps) # (4)! return result.output ``` 1. Define a method on the dependency to make the system prompt easier to customise. 1. Call the system prompt factory from within the system prompt function. 1. Application code that calls the agent, in a real application this might be an API endpoint. 1. Call the agent from within the application code, in a real application this call might be deep within a call stack. Note `app_deps` here will NOT be used when deps are overridden. *(This example is complete, it can be run "as is")* test_joke_app.py ```python from joke_app import MyDeps, application_code, joke_agent class TestMyDeps(MyDeps): # (1)! async def system_prompt_factory(self) -> str: return 'test prompt' async def test_application_code(): test_deps = TestMyDeps('test_key', None) # (2)! with joke_agent.override(deps=test_deps): # (3)! joke = await application_code('Tell me a joke.') # (4)! assert joke.startswith('Did you hear about the toothpaste scandal?') ``` 1. Define a subclass of `MyDeps` in tests to customise the system prompt factory. 1. Create an instance of the test dependency, we don't need to pass an `http_client` here as it's not used. 1. Override the dependencies of the agent for the duration of the `with` block, `test_deps` will be used when the agent is run. 1. Now we can safely call our application code, the agent will use the overridden dependencies. ## Examples The following examples demonstrate how to use dependencies in PydanticAI: - [Weather Agent](../examples/weather-agent/) - [SQL Generation](../examples/sql-gen/) - [RAG](../examples/rag/) # Messages and chat history PydanticAI provides access to messages exchanged during an agent run. These messages can be used both to continue a coherent conversation, and to understand how an agent performed. ### Accessing Messages from Results After running an agent, you can access the messages exchanged during that run from the `result` object. Both RunResult (returned by Agent.run, Agent.run_sync) and StreamedRunResult (returned by Agent.run_stream) have the following methods: - all_messages(): returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, all_messages_json(). - new_messages(): returns only the messages from the current run. There's also a variant that returns JSON bytes, new_messages_json(). StreamedRunResult and complete messages On StreamedRunResult, the messages returned from these methods will only include the final result message once the stream has finished. E.g. you've awaited one of the following coroutines: - StreamedRunResult.stream() - StreamedRunResult.stream_text() - StreamedRunResult.stream_structured() - StreamedRunResult.get_output() **Note:** The final result message will NOT be added to result messages if you use .stream_text(delta=True) since in this case the result content is never built as one string. Example of accessing methods on a RunResult : run_result_messages.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result = agent.run_sync('Tell me a joke.') print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. 
# all messages from the run print(result.all_messages()) """ [ ModelRequest( parts=[ SystemPromptPart( content='Be a helpful assistant.', timestamp=datetime.datetime(...), ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), ), ] ), ModelResponse( parts=[ TextPart( content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ] """ ``` *(This example is complete, it can be run "as is")* Example of accessing methods on a StreamedRunResult : streamed_run_result_messages.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') async def main(): async with agent.run_stream('Tell me a joke.') as result: # incomplete messages before the stream finishes print(result.all_messages()) """ [ ModelRequest( parts=[ SystemPromptPart( content='Be a helpful assistant.', timestamp=datetime.datetime(...), ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), ), ] ) ] """ async for text in result.stream_text(): print(text) #> Did you hear #> Did you hear about the toothpaste #> Did you hear about the toothpaste scandal? They called #> Did you hear about the toothpaste scandal? They called it Colgate. # complete messages once the stream finishes print(result.all_messages()) """ [ ModelRequest( parts=[ SystemPromptPart( content='Be a helpful assistant.', timestamp=datetime.datetime(...), ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), ), ] ), ModelResponse( parts=[ TextPart( content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], usage=Usage(request_tokens=50, response_tokens=12, total_tokens=62), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ] """ ``` *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* ### Using Messages as Input for Further Agent Runs The primary use of message histories in PydanticAI is to maintain context across multiple agent runs. To use existing messages in a run, pass them to the `message_history` parameter of Agent.run, Agent.run_sync or Agent.run_stream. If `message_history` is set and not empty, a new system prompt is not generated — we assume the existing message history includes a system prompt. Reusing messages in a conversation ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result1 = agent.run_sync('Tell me a joke.') print(result1.output) #> Did you hear about the toothpaste scandal? They called it Colgate. result2 = agent.run_sync('Explain?', message_history=result1.new_messages()) print(result2.output) #> This is an excellent joke invented by Samuel Colvin, it needs no explanation. print(result2.all_messages()) """ [ ModelRequest( parts=[ SystemPromptPart( content='Be a helpful assistant.', timestamp=datetime.datetime(...), ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), ), ] ), ModelResponse( parts=[ TextPart( content='Did you hear about the toothpaste scandal? They called it Colgate.' 
) ], usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ModelRequest( parts=[ UserPromptPart( content='Explain?', timestamp=datetime.datetime(...), ) ] ), ModelResponse( parts=[ TextPart( content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], usage=Usage(requests=1, request_tokens=61, response_tokens=26, total_tokens=87), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ] """ ``` *(This example is complete, it can be run "as is")* ## Storing and loading messages (to JSON) While maintaining conversation state in memory is enough for many applications, you may often want to store the message history of an agent run on disk or in a database. This might be for evals, for sharing data between Python and JavaScript/TypeScript, or any number of other use cases. The intended way to do this is using a `TypeAdapter`. We export ModelMessagesTypeAdapter that can be used for this, or you can create your own. Here's an example showing how: serialize messages to json ```python from pydantic_core import to_jsonable_python from pydantic_ai import Agent from pydantic_ai.messages import ModelMessagesTypeAdapter # (1)! agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result1 = agent.run_sync('Tell me a joke.') history_step_1 = result1.all_messages() as_python_objects = to_jsonable_python(history_step_1) # (2)! same_history_as_step_1 = ModelMessagesTypeAdapter.validate_python(as_python_objects) result2 = agent.run_sync( # (3)! 'Tell me a different joke.', message_history=same_history_as_step_1 ) ``` 1. Alternatively, you can create a `TypeAdapter` from scratch: ```python from pydantic import TypeAdapter from pydantic_ai.messages import ModelMessage ModelMessagesTypeAdapter = TypeAdapter(list[ModelMessage]) ``` 1. Alternatively you can serialize to/from JSON directly: ```python from pydantic_core import to_json ... as_json_objects = to_json(history_step_1) same_history_as_step_1 = ModelMessagesTypeAdapter.validate_json(as_json_objects) ``` 1. You can now continue the conversation with history `same_history_as_step_1` despite creating a new agent run. *(This example is complete, it can be run "as is")* ## Other ways of using messages Since messages are defined by simple dataclasses, you can manually create and manipulate them, e.g. for testing. The message format is independent of the model used, so you can use messages in different agents, or the same agent with different models. In the example below, we reuse the messages from the first agent run, which uses the `openai:gpt-4o` model, in a second agent run using the `google-gla:gemini-1.5-pro` model. Reusing messages with a different model ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', system_prompt='Be a helpful assistant.') result1 = agent.run_sync('Tell me a joke.') print(result1.output) #> Did you hear about the toothpaste scandal? They called it Colgate. result2 = agent.run_sync( 'Explain?', model='google-gla:gemini-1.5-pro', message_history=result1.new_messages(), ) print(result2.output) #> This is an excellent joke invented by Samuel Colvin, it needs no explanation.
print(result2.all_messages()) """ [ ModelRequest( parts=[ SystemPromptPart( content='Be a helpful assistant.', timestamp=datetime.datetime(...), ), UserPromptPart( content='Tell me a joke.', timestamp=datetime.datetime(...), ), ] ), ModelResponse( parts=[ TextPart( content='Did you hear about the toothpaste scandal? They called it Colgate.' ) ], usage=Usage(requests=1, request_tokens=60, response_tokens=12, total_tokens=72), model_name='gpt-4o', timestamp=datetime.datetime(...), ), ModelRequest( parts=[ UserPromptPart( content='Explain?', timestamp=datetime.datetime(...), ) ] ), ModelResponse( parts=[ TextPart( content='This is an excellent joke invented by Samuel Colvin, it needs no explanation.' ) ], usage=Usage(requests=1, request_tokens=61, response_tokens=26, total_tokens=87), model_name='gemini-1.5-pro', timestamp=datetime.datetime(...), ), ] """ ``` ## Processing Message History Sometimes you may want to modify the message history before it's sent to the model. This could be for privacy reasons (filtering out sensitive information), to save costs on tokens, to give less context to the LLM, or custom processing logic. PydanticAI provides a `history_processors` parameter on `Agent` that allows you to intercept and modify the message history before each model request. ### Usage The `history_processors` is a list of callables that take a list of ModelMessage and return a modified list of the same type. Each processor is applied in sequence, and processors can be either synchronous or asynchronous. simple_history_processor.py ```python from pydantic_ai import Agent from pydantic_ai.messages import ( ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart, ) def filter_responses(messages: list[ModelMessage]) -> list[ModelMessage]: """Remove all ModelResponse messages, keeping only ModelRequest messages.""" return [msg for msg in messages if isinstance(msg, ModelRequest)] # Create agent with history processor agent = Agent('openai:gpt-4o', history_processors=[filter_responses]) # Example: Create some conversation history message_history = [ ModelRequest(parts=[UserPromptPart(content='What is 2+2?')]), ModelResponse(parts=[TextPart(content='2+2 equals 4')]), # This will be filtered out ] # When you run the agent, the history processor will filter out ModelResponse messages # result = agent.run_sync('What about 3+3?', message_history=message_history) ``` #### Keep Only Recent Messages You can use the `history_processor` to only keep the recent messages: keep_recent_messages.py ```python from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage async def keep_recent_messages(messages: list[ModelMessage]) -> list[ModelMessage]: """Keep only the last 5 messages to manage token usage.""" return messages[-5:] if len(messages) > 5 else messages agent = Agent('openai:gpt-4o', history_processors=[keep_recent_messages]) # Example: Even with a long conversation history, only the last 5 messages are sent to the model long_conversation_history: list[ModelMessage] = [] # Your long conversation history here # result = agent.run_sync('What did we discuss?', message_history=long_conversation_history) ``` #### `RunContext` parameter History processors can optionally accept a RunContext parameter to access additional information about the current run, such as dependencies, model information, and usage statistics: context_aware_processor.py ```python from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage from pydantic_ai.tools import RunContext def 
context_aware_processor( ctx: RunContext[None], messages: list[ModelMessage], ) -> list[ModelMessage]: # Access current usage current_tokens = ctx.usage.total_tokens # Filter messages based on context if current_tokens > 1000: return messages[-3:] # Keep only recent messages when token usage is high return messages agent = Agent('openai:gpt-4o', history_processors=[context_aware_processor]) ``` This allows for more sophisticated message processing based on the current state of the agent run. #### Summarize Old Messages Use an LLM to summarize older messages to preserve context while reducing tokens. summarize_old_messages.py ```python from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage # Use a cheaper model to summarize old messages. summarize_agent = Agent( 'openai:gpt-4o-mini', instructions=""" Summarize this conversation, omitting small talk and unrelated topics. Focus on the technical discussion and next steps. """, ) async def summarize_old_messages(messages: list[ModelMessage]) -> list[ModelMessage]: # Summarize the oldest 10 messages if len(messages) > 10: oldest_messages = messages[:10] summary = await summarize_agent.run(message_history=oldest_messages) # Return the last message and the summary return summary.new_messages() + messages[-1:] return messages agent = Agent('openai:gpt-4o', history_processors=[summarize_old_messages]) ``` ### Testing History Processors You can test what messages are actually sent to the model provider using FunctionModel: test_history_processor.py ```python import pytest from pydantic_ai import Agent from pydantic_ai.messages import ( ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart, ) from pydantic_ai.models.function import AgentInfo, FunctionModel @pytest.fixture def received_messages() -> list[ModelMessage]: return [] @pytest.fixture def function_model(received_messages: list[ModelMessage]) -> FunctionModel: def capture_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # Capture the messages that the provider actually receives received_messages.clear() received_messages.extend(messages) return ModelResponse(parts=[TextPart(content='Provider response')]) return FunctionModel(capture_model_function) def test_history_processor(function_model: FunctionModel, received_messages: list[ModelMessage]): def filter_responses(messages: list[ModelMessage]) -> list[ModelMessage]: return [msg for msg in messages if isinstance(msg, ModelRequest)] agent = Agent(function_model, history_processors=[filter_responses]) message_history = [ ModelRequest(parts=[UserPromptPart(content='Question 1')]), ModelResponse(parts=[TextPart(content='Answer 1')]), ] agent.run_sync('Question 2', message_history=message_history) assert received_messages == [ ModelRequest(parts=[UserPromptPart(content='Question 1')]), ModelRequest(parts=[UserPromptPart(content='Question 2')]), ] ``` ### Multiple Processors You can also use multiple processors: multiple_history_processors.py ```python from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage, ModelRequest def filter_responses(messages: list[ModelMessage]) -> list[ModelMessage]: return [msg for msg in messages if isinstance(msg, ModelRequest)] def summarize_old_messages(messages: list[ModelMessage]) -> list[ModelMessage]: return messages[-5:] agent = Agent('openai:gpt-4o', history_processors=[filter_responses, summarize_old_messages]) ``` In this case, the `filter_responses` processor will be applied first, and the `summarize_old_messages` processor 
will be applied second. ## Examples For a more complete example of using messages in conversations, see the [chat app](../examples/chat-app/) example. # Multi-agent Applications There are roughly four levels of complexity when building applications with PydanticAI: 1. Single agent workflows — what most of the `pydantic_ai` documentation covers 1. [Agent delegation](#agent-delegation) — agents using another agent via tools 1. [Programmatic agent hand-off](#programmatic-agent-hand-off) — one agent runs, then application code calls another agent 1. [Graph based control flow](../graph/) — for the most complex cases, a graph-based state machine can be used to control the execution of multiple agents Of course, you can combine multiple strategies in a single application. ## Agent delegation "Agent delegation" refers to the scenario where an agent delegates work to another agent, then takes back control when the delegate agent (the agent called from within a tool) finishes. If you want to hand off control to another agent completely, without coming back to the first agent, you can use an [output function](../output/#output-functions). Since agents are stateless and designed to be global, you do not need to include the agent itself in agent [dependencies](../dependencies/). You'll generally want to pass ctx.usage to the usage keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. Multiple models Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final result.usage() of the run will not be possible, but you can still use UsageLimits to avoid unexpected costs. agent_delegation_simple.py ```python from pydantic_ai import Agent, RunContext from pydantic_ai.usage import UsageLimits joke_selection_agent = Agent( # (1)! 'openai:gpt-4o', system_prompt=( 'Use the `joke_factory` to generate some jokes, then choose the best. ' 'You must return just a single joke.' ), ) joke_generation_agent = Agent( # (2)! 'google-gla:gemini-1.5-flash', output_type=list[str] ) @joke_selection_agent.tool async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: r = await joke_generation_agent.run( # (3)! f'Please generate {count} jokes.', usage=ctx.usage, # (4)! ) return r.output # (5)! result = joke_selection_agent.run_sync( 'Tell me a joke.', usage_limits=UsageLimits(request_limit=5, total_tokens_limit=300), ) print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) #> Usage(requests=3, request_tokens=204, response_tokens=24, total_tokens=228) ``` 1. The "parent" or controlling agent. 1. The "delegate" agent, which is called from within a tool of the parent agent. 1. Call the delegate agent from within a tool of the parent agent. 1. Pass the usage from the parent agent to the delegate agent so the final result.usage() includes the usage from both agents. 1. Since the function returns `list[str]`, and the `output_type` of `joke_generation_agent` is also `list[str]`, we can simply return `r.output` from the tool. 
*(This example is complete, it can be run "as is")* The control flow for this example is pretty simple and can be summarised as follows: ``` graph TD START --> joke_selection_agent joke_selection_agent --> joke_factory["joke_factory (tool)"] joke_factory --> joke_generation_agent joke_generation_agent --> joke_factory joke_factory --> joke_selection_agent joke_selection_agent --> END ``` ### Agent delegation and dependencies Generally the delegate agent needs to either have the same [dependencies](../dependencies/) as the calling agent, or dependencies which are a subset of the calling agent's dependencies. Initializing dependencies We say "generally" above since there's nothing to stop you initializing dependencies within a tool call and therefore using interdependencies in a delegate agent that are not available on the parent, this should often be avoided since it can be significantly slower than reusing connections etc. from the parent agent. agent_delegation_deps.py ```python from dataclasses import dataclass import httpx from pydantic_ai import Agent, RunContext @dataclass class ClientAndKey: # (1)! http_client: httpx.AsyncClient api_key: str joke_selection_agent = Agent( 'openai:gpt-4o', deps_type=ClientAndKey, # (2)! system_prompt=( 'Use the `joke_factory` tool to generate some jokes on the given subject, ' 'then choose the best. You must return just a single joke.' ), ) joke_generation_agent = Agent( 'gemini-1.5-flash', deps_type=ClientAndKey, # (4)! output_type=list[str], system_prompt=( 'Use the "get_jokes" tool to get some jokes on the given subject, ' 'then extract each joke into a list.' ), ) @joke_selection_agent.tool async def joke_factory(ctx: RunContext[ClientAndKey], count: int) -> list[str]: r = await joke_generation_agent.run( f'Please generate {count} jokes.', deps=ctx.deps, # (3)! usage=ctx.usage, ) return r.output @joke_generation_agent.tool # (5)! async def get_jokes(ctx: RunContext[ClientAndKey], count: int) -> str: response = await ctx.deps.http_client.get( 'https://example.com', params={'count': count}, headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, ) response.raise_for_status() return response.text async def main(): async with httpx.AsyncClient() as client: deps = ClientAndKey(client, 'foobar') result = await joke_selection_agent.run('Tell me a joke.', deps=deps) print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! #> Usage(requests=4, request_tokens=309, response_tokens=32, total_tokens=341) ``` 1. Define a dataclass to hold the client and API key dependencies. 1. Set the `deps_type` of the calling agent — `joke_selection_agent` here. 1. Pass the dependencies to the delegate agent's run method within the tool call. 1. Also set the `deps_type` of the delegate agent — `joke_generation_agent` here. 1. Define a tool on the delegate agent that uses the dependencies to make an HTTP request. 1. Usage now includes 4 requests — 2 from the calling agent and 2 from the delegate agent. 
*(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* This example shows how even a fairly simple agent delegation can lead to a complex control flow: ``` graph TD START --> joke_selection_agent joke_selection_agent --> joke_factory["joke_factory (tool)"] joke_factory --> joke_generation_agent joke_generation_agent --> get_jokes["get_jokes (tool)"] get_jokes --> http_request["HTTP request"] http_request --> get_jokes get_jokes --> joke_generation_agent joke_generation_agent --> joke_factory joke_factory --> joke_selection_agent joke_selection_agent --> END ``` ## Programmatic agent hand-off "Programmatic agent hand-off" refers to the scenario where multiple agents are called in succession, with application code and/or a human in the loop responsible for deciding which agent to call next. Here agents don't need to use the same deps. Here we show two agents used in succession, the first to find a flight and the second to extract the user's seat preference. programmatic_handoff.py ```python from typing import Literal, Union from pydantic import BaseModel, Field from rich.prompt import Prompt from pydantic_ai import Agent, RunContext from pydantic_ai.messages import ModelMessage from pydantic_ai.usage import Usage, UsageLimits class FlightDetails(BaseModel): flight_number: str class Failed(BaseModel): """Unable to find a satisfactory choice.""" flight_search_agent = Agent[None, Union[FlightDetails, Failed]]( # (1)! 'openai:gpt-4o', output_type=Union[FlightDetails, Failed], # type: ignore system_prompt=( 'Use the "flight_search" tool to find a flight ' 'from the given origin to the given destination.' ), ) @flight_search_agent.tool # (2)! async def flight_search( ctx: RunContext[None], origin: str, destination: str ) -> Union[FlightDetails, None]: # in reality, this would call a flight search API or # use a browser to scrape a flight search website return FlightDetails(flight_number='AK456') usage_limits = UsageLimits(request_limit=15) # (3)! async def find_flight(usage: Usage) -> Union[FlightDetails, None]: # (4)! message_history: Union[list[ModelMessage], None] = None for _ in range(3): prompt = Prompt.ask( 'Where would you like to fly from and to?', ) result = await flight_search_agent.run( prompt, message_history=message_history, usage=usage, usage_limits=usage_limits, ) if isinstance(result.output, FlightDetails): return result.output else: message_history = result.all_messages( output_tool_return_content='Please try again.' ) class SeatPreference(BaseModel): row: int = Field(ge=1, le=30) seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] # This agent is responsible for extracting the user's seat selection seat_preference_agent = Agent[None, Union[SeatPreference, Failed]]( # (5)! 'openai:gpt-4o', output_type=Union[SeatPreference, Failed], # type: ignore system_prompt=( "Extract the user's seat preference. " 'Seats A and F are window seats. ' 'Row 1 is the front row and has extra leg room. ' 'Rows 14, and 20 also have extra leg room. ' ), ) async def find_seat(usage: Usage) -> SeatPreference: # (6)! message_history: Union[list[ModelMessage], None] = None while True: answer = Prompt.ask('What seat would you like?') result = await seat_preference_agent.run( answer, message_history=message_history, usage=usage, usage_limits=usage_limits, ) if isinstance(result.output, SeatPreference): return result.output else: print('Could not understand seat preference. Please try again.') message_history = result.all_messages() async def main(): # (7)! 
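# A single Usage object is shared across both agents, so the usage limits apply to the whole app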
usage: Usage = Usage() opt_flight_details = await find_flight(usage) if opt_flight_details is not None: print(f'Flight found: {opt_flight_details.flight_number}') #> Flight found: AK456 seat_preference = await find_seat(usage) print(f'Seat preference: {seat_preference}') #> Seat preference: row=1 seat='A' ``` 1. Define the first agent, which finds a flight. We use an explicit type annotation until [PEP-747](https://peps.python.org/pep-0747/) lands, see [structured output](../output/#structured-output). We use a union as the output type so the model can communicate if it's unable to find a satisfactory choice; internally, each member of the union will be registered as a separate tool. 1. Define a tool on the agent to find a flight. In this simple case we could dispense with the tool and just define the agent to return structured data, then search for a flight, but in more complex scenarios the tool would be necessary. 1. Define usage limits for the entire app. 1. Define a function to find a flight, which asks the user for their preferences and then calls the agent to find a flight. 1. As with `flight_search_agent` above, we use an explicit type annotation to define the agent. 1. Define a function to find the user's seat preference, which asks the user for their seat preference and then calls the agent to extract the seat preference. 1. Now that we've put our logic for running each agent into separate functions, our main app becomes very simple. *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)* The control flow for this example can be summarised as follows: ``` graph TB START --> ask_user_flight["ask user for flight"] subgraph find_flight flight_search_agent --> ask_user_flight ask_user_flight --> flight_search_agent end flight_search_agent --> ask_user_seat["ask user for seat"] flight_search_agent --> END subgraph find_seat seat_preference_agent --> ask_user_seat ask_user_seat --> seat_preference_agent end seat_preference_agent --> END ``` ## Pydantic Graphs See the [graph](../graph/) documentation on when and how to use graphs. ## Examples The following examples demonstrate how to use dependencies in PydanticAI: - [Flight booking](../examples/flight-booking/) # Function Tools Function tools provide a mechanism for models to retrieve extra information to help them generate a response. They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. If you want a model to be able to call a function as its final action, without the result being sent back to the model, you can use an [output function](../output/#output-functions) instead. Function tools vs. RAG Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. 
See [#58](https://github.com/pydantic/pydantic-ai/issues/58)) There are a number of ways to register tools with an agent: - via the @agent.tool decorator — for tools that need access to the agent context - via the @agent.tool_plain decorator — for tools that do not need access to the agent context - via the tools keyword argument to `Agent` which can take either plain functions, or instances of Tool ## Registering Function Tools via Decorator `@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context. Here's an example using both: dice_game.py ```python import random from pydantic_ai import Agent, RunContext agent = Agent( 'google-gla:gemini-1.5-flash', # (1)! deps_type=str, # (2)! system_prompt=( "You're a dice game, you should roll the die and see if the number " "you get back matches the user's guess. If so, tell them they're a winner. " "Use the player's name in the response." ), ) @agent.tool_plain # (3)! def roll_dice() -> str: """Roll a six-sided die and return the result.""" return str(random.randint(1, 6)) @agent.tool # (4)! def get_player_name(ctx: RunContext[str]) -> str: """Get the player's name.""" return ctx.deps dice_result = agent.run_sync('My guess is 4', deps='Anne') # (5)! print(dice_result.output) #> Congratulations Anne, you guessed correctly! You're a winner! ``` 1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model. 1. We pass the user's name as the dependency, to keep things simple we use just the name as a string as the dependency. 1. This tool doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case. 1. This tool needs the player's name, so it uses `RunContext` to access dependencies which are just the player's name in this case. 1. Run the agent, passing the player's name as the dependency. *(This example is complete, it can be run "as is")* Let's print the messages from that game to see what happened: dice_game_messages.py ```python from dice_game import dice_result print(dice_result.all_messages()) """ [ ModelRequest( parts=[ SystemPromptPart( content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.", timestamp=datetime.datetime(...), ), UserPromptPart( content='My guess is 4', timestamp=datetime.datetime(...), ), ] ), ModelResponse( parts=[ ToolCallPart( tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], usage=Usage(requests=1, request_tokens=90, response_tokens=2, total_tokens=92), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), ModelRequest( parts=[ ToolReturnPart( tool_name='roll_dice', content='4', tool_call_id='pyd_ai_tool_call_id', timestamp=datetime.datetime(...), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='get_player_name', args={}, tool_call_id='pyd_ai_tool_call_id' ) ], usage=Usage(requests=1, request_tokens=91, response_tokens=4, total_tokens=95), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_player_name', content='Anne', tool_call_id='pyd_ai_tool_call_id', timestamp=datetime.datetime(...), ) ] ), ModelResponse( parts=[ TextPart( content="Congratulations Anne, you guessed correctly! You're a winner!" 
) ], usage=Usage( requests=1, request_tokens=92, response_tokens=12, total_tokens=104 ), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), ), ] """ ``` We can represent this with a diagram: ``` sequenceDiagram participant Agent participant LLM Note over Agent: Send prompts Agent ->> LLM: System: "You're a dice game..."
User: "My guess is 4" activate LLM Note over LLM: LLM decides to use
a tool LLM ->> Agent: Call tool
roll_dice() deactivate LLM activate Agent Note over Agent: Rolls a six-sided die Agent -->> LLM: ToolReturn
"4" deactivate Agent activate LLM Note over LLM: LLM decides to use
another tool LLM ->> Agent: Call tool
get_player_name() deactivate LLM activate Agent Note over Agent: Retrieves player name Agent -->> LLM: ToolReturn
"Anne" deactivate Agent activate LLM Note over LLM: LLM constructs final response LLM ->> Agent: ModelResponse
"Congratulations Anne, ..." deactivate LLM Note over Agent: Game session complete ``` ## Registering Function Tools via Agent Argument As well as using the decorators, we can register tools via the `tools` argument to the Agent constructor. This is useful when you want to reuse tools, and can also give more fine-grained control over the tools. dice_game_tool_kwarg.py ```python import random from pydantic_ai import Agent, RunContext, Tool system_prompt = """\ You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response. """ def roll_dice() -> str: """Roll a six-sided die and return the result.""" return str(random.randint(1, 6)) def get_player_name(ctx: RunContext[str]) -> str: """Get the player's name.""" return ctx.deps agent_a = Agent( 'google-gla:gemini-1.5-flash', deps_type=str, tools=[roll_dice, get_player_name], # (1)! system_prompt=system_prompt, ) agent_b = Agent( 'google-gla:gemini-1.5-flash', deps_type=str, tools=[ # (2)! Tool(roll_dice, takes_ctx=False), Tool(get_player_name, takes_ctx=True), ], system_prompt=system_prompt, ) dice_result = {} dice_result['a'] = agent_a.run_sync('My guess is 6', deps='Yashar') dice_result['b'] = agent_b.run_sync('My guess is 4', deps='Anne') print(dice_result['a'].output) #> Tough luck, Yashar, you rolled a 4. Better luck next time. print(dice_result['b'].output) #> Congratulations Anne, you guessed correctly! You're a winner! ``` 1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions, the function signature is inspected to determine if the tool takes RunContext. 1. `agent_a` and `agent_b` are identical — but we can use Tool to reuse tool definitions and give more fine-grained control over how tools are defined, e.g. setting their name or description, or using a custom [`prepare`](#tool-prepare) method. *(This example is complete, it can be run "as is")* ## Function Tool Output Tools can return anything that Pydantic can serialize to JSON, as well as audio, video, image or document content depending on the types of [multi-modal input](../input/) the model supports: function_tool_output.py ```python from datetime import datetime from pydantic import BaseModel from pydantic_ai import Agent, DocumentUrl, ImageUrl from pydantic_ai.models.openai import OpenAIResponsesModel class User(BaseModel): name: str age: int agent = Agent(model=OpenAIResponsesModel('gpt-4o')) @agent.tool_plain def get_current_time() -> datetime: return datetime.now() @agent.tool_plain def get_user() -> User: return User(name='John', age=30) @agent.tool_plain def get_company_logo() -> ImageUrl: return ImageUrl(url='https://iili.io/3Hs4FMg.png') @agent.tool_plain def get_document() -> DocumentUrl: return DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf') result = agent.run_sync('What time is it?') print(result.output) #> The current time is 10:45 PM on April 17, 2025. result = agent.run_sync('What is the user name?') print(result.output) #> The user's name is John. result = agent.run_sync('What is the company name in the logo?') print(result.output) #> The company name in the logo is "Pydantic." result = agent.run_sync('What is the main content of the document?') print(result.output) #> The document contains just the text "Dummy PDF file." ``` *(This example is complete, it can be run "as is")* Some models (e.g. 
Gemini) natively support semi-structured return values, while some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. ### Advanced Tool Returns For scenarios where you need more control over both the tool's return value and the content sent to the model, you can use ToolReturn. This is particularly useful when you want to: - Provide rich multi-modal content (images, documents, etc.) to the model as context - Separate the programmatic return value from the model's context - Include additional metadata that shouldn't be sent to the LLM Here's an example of a computer automation tool that captures screenshots and provides visual feedback: advanced_tool_return.py ```python import time from pydantic_ai import Agent from pydantic_ai.messages import ToolReturn, BinaryContent agent = Agent('openai:gpt-4o') @agent.tool_plain def click_and_capture(x: int, y: int) -> ToolReturn: """Click at coordinates and show before/after screenshots.""" # Take screenshot before action before_screenshot = capture_screen() # Perform click operation perform_click(x, y) time.sleep(0.5) # Wait for UI to update # Take screenshot after action after_screenshot = capture_screen() return ToolReturn( return_value=f"Successfully clicked at ({x}, {y})", content=[ f"Clicked at coordinates ({x}, {y}). Here's the comparison:", "Before:", BinaryContent(data=before_screenshot, media_type="image/png"), "After:", BinaryContent(data=after_screenshot, media_type="image/png"), "Please analyze the changes and suggest next steps." ], metadata={ "coordinates": {"x": x, "y": y}, "action_type": "click_and_capture", "timestamp": time.time() } ) # The model receives the rich visual content for analysis # while your application can access the structured return_value and metadata result = agent.run_sync("Click on the submit button and tell me what happened") print(result.output) # The model can analyze the screenshots and provide detailed feedback ``` - **`return_value`**: The actual return value used in the tool response. This is what gets serialized and sent back to the model as the tool's result. - **`content`**: A sequence of content (text, images, documents, etc.) that provides additional context to the model. This appears as a separate user message. - **`metadata`**: Optional metadata that your application can access but is not sent to the LLM. Useful for logging, debugging, or additional processing. Some other AI frameworks call this feature "artifacts". This separation allows you to provide rich context to the model while maintaining clean, structured return values for your application logic. ## Function Tools vs. Structured Outputs As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output. ## Function tools and schema Function parameters are extracted from the function signature, and all parameters except `RunContext` are used to build the schema for that tool call. Even better, PydanticAI extracts the docstring from functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema. 
[Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy`, and `sphinx` style docstrings. PydanticAI will infer the format to use based on the docstring, but you can explicitly set it using docstring_format. You can also enforce parameter requirements by setting `require_parameter_descriptions=True`. This will raise a UserError if a parameter description is missing. To demonstrate a tool's schema, here we use FunctionModel to print the schema a model would receive: tool_schema.py ```python from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart from pydantic_ai.models.function import AgentInfo, FunctionModel agent = Agent() @agent.tool_plain(docstring_format='google', require_parameter_descriptions=True) def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: """Get me foobar. Args: a: apple pie b: banana cake c: carrot smoothie """ return f'{a} {b} {c}' def print_schema(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: tool = info.function_tools[0] print(tool.description) #> Get me foobar. print(tool.parameters_json_schema) """ { 'additionalProperties': False, 'properties': { 'a': {'description': 'apple pie', 'type': 'integer'}, 'b': {'description': 'banana cake', 'type': 'string'}, 'c': { 'additionalProperties': {'items': {'type': 'number'}, 'type': 'array'}, 'description': 'carrot smoothie', 'type': 'object', }, }, 'required': ['a', 'b', 'c'], 'type': 'object', } """ return ModelResponse(parts=[TextPart('foobar')]) agent.run_sync('hello', model=FunctionModel(print_schema)) ``` *(This example is complete, it can be run "as is")* If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. Here's an example where we use TestModel.last_model_request_parameters to inspect the tool schema that would be passed to the model. single_parameter_tool.py ```python from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models.test import TestModel agent = Agent() class Foobar(BaseModel): """This is a Foobar""" x: int y: str z: float = 3.14 @agent.tool_plain def foobar(f: Foobar) -> str: return str(f) test_model = TestModel() result = agent.run_sync('hello', model=test_model) print(result.output) #> {"foobar":"x=0 y='a' z=3.14"} print(test_model.last_model_request_parameters.function_tools) """ [ ToolDefinition( name='foobar', description='This is a Foobar', parameters_json_schema={ 'properties': { 'x': {'type': 'integer'}, 'y': {'type': 'string'}, 'z': {'default': 3.14, 'type': 'number'}, }, 'required': ['x', 'y'], 'title': 'Foobar', 'type': 'object', }, ) ] """ ``` *(This example is complete, it can be run "as is")* If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of \*args or \*\*kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the `Tool.from_schema` function. 
With this you provide the name, description and JSON schema for the function directly: ```python from pydantic_ai import Agent, Tool from pydantic_ai.models.test import TestModel def foobar(**kwargs) -> str: return kwargs['a'] + kwargs['b'] tool = Tool.from_schema( function=foobar, name='sum', description='Sum two numbers.', json_schema={ 'additionalProperties': False, 'properties': { 'a': {'description': 'the first number', 'type': 'integer'}, 'b': {'description': 'the second number', 'type': 'integer'}, }, 'required': ['a', 'b'], 'type': 'object', } ) test_model = TestModel() agent = Agent(test_model, tools=[tool]) result = agent.run_sync('testing...') print(result.output) #> {"sum":0} ``` Please note that validation of the tool arguments will not be performed, and this will pass all arguments as keyword arguments. ## Dynamic Function tools Tools can optionally be defined with another function: `prepare`, which is called at each step of a run to customize the definition of the tool passed to the model, or to omit the tool completely from that step. A `prepare` method can be registered via the `prepare` kwarg to any of the tool registration mechanisms: - @agent.tool decorator - @agent.tool_plain decorator - Tool dataclass The `prepare` method should be of type ToolPrepareFunc, a function which takes RunContext and a pre-built ToolDefinition, and should either return that `ToolDefinition` with or without modifying it, return a new `ToolDefinition`, or return `None` to indicate the tool should not be registered for that step. Here's a simple `prepare` method that only includes the tool if the value of the dependency is `42`. As with the previous example, we use TestModel to demonstrate the behavior without calling a real model. tool_only_if_42.py ```python from typing import Union from pydantic_ai import Agent, RunContext from pydantic_ai.tools import ToolDefinition agent = Agent('test') async def only_if_42( ctx: RunContext[int], tool_def: ToolDefinition ) -> Union[ToolDefinition, None]: if ctx.deps == 42: return tool_def @agent.tool(prepare=only_if_42) def hitchhiker(ctx: RunContext[int], answer: str) -> str: return f'{ctx.deps} {answer}' result = agent.run_sync('testing...', deps=41) print(result.output) #> success (no tool calls) result = agent.run_sync('testing...', deps=42) print(result.output) #> {"hitchhiker":"42 a"} ``` *(This example is complete, it can be run "as is")* Here's a more complex example where we change the description of the `name` parameter based on the value of `deps`. For the sake of variation, we create this tool using the Tool dataclass. customize_name.py ```python from __future__ import annotations from typing import Literal from pydantic_ai import Agent, RunContext from pydantic_ai.models.test import TestModel from pydantic_ai.tools import Tool, ToolDefinition def greet(name: str) -> str: return f'hello {name}' async def prepare_greet( ctx: RunContext[Literal['human', 'machine']], tool_def: ToolDefinition ) -> ToolDefinition | None: d = f'Name of the {ctx.deps} to greet.'
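# Update the tool's JSON schema in place so the `name` parameter description reflects the current deps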
tool_def.parameters_json_schema['properties']['name']['description'] = d return tool_def greet_tool = Tool(greet, prepare=prepare_greet) test_model = TestModel() agent = Agent(test_model, tools=[greet_tool], deps_type=Literal['human', 'machine']) result = agent.run_sync('testing...', deps='human') print(result.output) #> {"greet":"hello a"} print(test_model.last_model_request_parameters.function_tools) """ [ ToolDefinition( name='greet', description='', parameters_json_schema={ 'additionalProperties': False, 'properties': { 'name': {'type': 'string', 'description': 'Name of the human to greet.'} }, 'required': ['name'], 'type': 'object', }, ) ] """ ``` *(This example is complete, it can be run "as is")* ## Agent-wide Dynamic Tool Preparation In addition to per-tool `prepare` methods, you can also define an agent-wide `prepare_tools` function. This function is called at each step of a run and allows you to filter or modify the list of all tool definitions available to the agent for that step. This is especially useful if you want to enable or disable multiple tools at once, or apply global logic based on the current context. The `prepare_tools` function should be of type ToolsPrepareFunc, which takes the RunContext and a list of ToolDefinition, and returns a new list of tool definitions (or `None` to disable all tools for that step). Note The list of tool definitions passed to `prepare_tools` includes both regular tools and tools from any MCP servers attached to the agent. Here's an example that makes all tools strict if the model is an OpenAI model: agent_prepare_tools_customize.py ```python from dataclasses import replace from typing import Union from pydantic_ai import Agent, RunContext from pydantic_ai.tools import ToolDefinition from pydantic_ai.models.test import TestModel async def turn_on_strict_if_openai( ctx: RunContext[None], tool_defs: list[ToolDefinition] ) -> Union[list[ToolDefinition], None]: if ctx.model.system == 'openai': return [replace(tool_def, strict=True) for tool_def in tool_defs] return tool_defs test_model = TestModel() agent = Agent(test_model, prepare_tools=turn_on_strict_if_openai) @agent.tool_plain def echo(message: str) -> str: return message agent.run_sync('testing...') assert test_model.last_model_request_parameters.function_tools[0].strict is None # Set the system attribute of the test_model to 'openai' test_model._system = 'openai' agent.run_sync('testing with openai...') assert test_model.last_model_request_parameters.function_tools[0].strict ``` *(This example is complete, it can be run "as is")* Here's another example that conditionally filters out the tools by name if the dependency (`ctx.deps`) is `True`: agent_prepare_tools_filter_out.py ```python from typing import Union from pydantic_ai import Agent, RunContext from pydantic_ai.tools import Tool, ToolDefinition def launch_potato(target: str) -> str: return f'Potato launched at {target}!' 
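# Agent-wide prepare_tools callable: drop the `launch_potato` tool whenever deps is True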
async def filter_out_tools_by_name( ctx: RunContext[bool], tool_defs: list[ToolDefinition] ) -> Union[list[ToolDefinition], None]: if ctx.deps: return [tool_def for tool_def in tool_defs if tool_def.name != 'launch_potato'] return tool_defs agent = Agent( 'test', tools=[Tool(launch_potato)], prepare_tools=filter_out_tools_by_name, deps_type=bool, ) result = agent.run_sync('testing...', deps=False) print(result.output) #> {"launch_potato":"Potato launched at a!"} result = agent.run_sync('testing...', deps=True) print(result.output) #> success (no tool calls) ``` *(This example is complete, it can be run "as is")* You can use `prepare_tools` to: - Dynamically enable or disable tools based on the current model, dependencies, or other context - Modify tool definitions globally (e.g., set all tools to strict mode, change descriptions, etc.) If both per-tool `prepare` and agent-wide `prepare_tools` are used, the per-tool `prepare` is applied first to each tool, and then `prepare_tools` is called with the resulting list of tool definitions. ## Tool Execution and Retries When a tool is executed, its arguments (provided by the LLM) are first validated against the function's signature using Pydantic. If validation fails (e.g., due to incorrect types or missing required arguments), a `ValidationError` is raised, and the framework automatically generates a RetryPromptPart containing the validation details. This prompt is sent back to the LLM, informing it of the error and allowing it to correct the parameters and retry the tool call. Beyond automatic validation errors, the tool's own internal logic can also explicitly request a retry by raising the ModelRetry exception. This is useful for situations where the parameters were technically valid, but an issue occurred during execution (like a transient network error, or the tool determining the initial attempt needs modification). ```python from pydantic_ai import ModelRetry def my_flaky_tool(query: str) -> str: if query == 'bad': # Tell the LLM the query was bad and it should try again raise ModelRetry("The query 'bad' is not allowed. Please provide a different query.") # ... process query ... return 'Success!' ``` Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception message, which is sent back to the LLM to guide its next attempt. Both `ValidationError` and `ModelRetry` respect the `retries` setting configured on the `Tool` or `Agent`. ## Third-Party Tools ### MCP Tools See the [MCP Client](../mcp/client/) documentation for how to use MCP servers with Pydantic AI. ### LangChain Tools If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the `pydantic_ai.ext.langchain.tool_from_langchain` convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. You will need to install the `langchain-community` package and any others required by the tool in question.
Here is how you can use the LangChain `DuckDuckGoSearchRun` tool, which requires the `duckduckgo-search` package: ```python from langchain_community.tools import DuckDuckGoSearchRun from pydantic_ai import Agent from pydantic_ai.ext.langchain import tool_from_langchain search = DuckDuckGoSearchRun() search_tool = tool_from_langchain(search) agent = Agent( 'google-gla:gemini-2.0-flash', tools=[search_tool], ) result = agent.run_sync('What is the release date of Elden Ring Nightreign?') # (1)! print(result.output) #> Elden Ring Nightreign is planned to be released on May 30, 2025. ``` 1. The release date of this game is the 30th of May 2025, which is after the knowledge cutoff for Gemini 2.0 (August 2024). ### ACI.dev Tools If you'd like to use a tool from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the `pydantic_ai.ext.aci.tool_from_aci` convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. You will need to install the `aci-sdk` package, set your ACI API key in the `ACI_API_KEY` environment variable, and pass your ACI "linked account owner ID" to the function. Here is how you can use the ACI.dev `TAVILY__SEARCH` tool: ```python import os from pydantic_ai import Agent from pydantic_ai.ext.aci import tool_from_aci tavily_search = tool_from_aci( 'TAVILY__SEARCH', linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID') ) agent = Agent( 'google-gla:gemini-2.0-flash', tools=[tavily_search] ) result = agent.run_sync('What is the release date of Elden Ring Nightreign?') # (1)! print(result.output) #> Elden Ring Nightreign is planned to be released on May 30, 2025. ``` 1. The release date of this game is the 30th of May 2025, which is after the knowledge cutoff for Gemini 2.0 (August 2024). # Models # Model Providers PydanticAI is model-agnostic and has built-in support for multiple model providers: - [OpenAI](openai/) - [Anthropic](anthropic/) - [Gemini](gemini/) (via two different APIs: Generative Language API and VertexAI API) - [Groq](groq/) - [Mistral](mistral/) - [Cohere](cohere/) - [Bedrock](bedrock/) ## OpenAI-compatible Providers In addition, many providers are compatible with the OpenAI API, and can be used with `OpenAIModel` in PydanticAI: - [DeepSeek](openai/#deepseek) - [Grok (xAI)](openai/#grok-xai) - [Ollama](openai/#ollama) - [OpenRouter](openai/#openrouter) - [Perplexity](openai/#perplexity) - [Fireworks AI](openai/#fireworks-ai) - [Together AI](openai/#together-ai) - [Azure AI Foundry](openai/#azure-ai-foundry) - [Heroku](openai/#heroku-ai) - [GitHub Models](openai/#github-models) PydanticAI also comes with [`TestModel`](../api/models/test/) and [`FunctionModel`](../api/models/function/) for testing and development. To use each model provider, you need to configure your local environment and make sure you have the right packages installed. ## Models and Providers PydanticAI uses a few key terms to describe how it interacts with different LLMs: - **Model**: This refers to the PydanticAI class used to make requests following a specific LLM API (generally by wrapping a vendor-provided SDK, like the `openai` python SDK). These classes implement a vendor-SDK-agnostic API, ensuring a single PydanticAI agent is portable to different LLM vendors without any other code changes just by swapping out the Model it uses.
Model classes are named roughly in the format `<VendorSdk>Model`, for example, we have `OpenAIModel`, `AnthropicModel`, `GeminiModel`, etc. When using a Model class, you specify the actual LLM model name (e.g., `gpt-4o`, `claude-3-5-sonnet-latest`, `gemini-1.5-flash`) as a parameter. - **Provider**: This refers to provider-specific classes which handle the authentication and connections to an LLM vendor. Passing a non-default *Provider* as a parameter to a Model is how you can ensure that your agent will make requests to a specific endpoint, or make use of a specific approach to authentication (e.g., you can use Vertex-specific auth with the `GeminiModel` by way of the `GoogleVertexProvider`). In particular, this is how you can make use of an AI gateway, or an LLM vendor that offers API compatibility with the vendor SDK used by an existing Model (such as `OpenAIModel`). - **Profile**: This refers to a description of how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used. For example, different models have different restrictions on the JSON schemas that can be used for tools, and the same schema transformer needs to be used for Gemini models whether you're using `GoogleModel` with model name `gemini-2.5-pro-preview`, or `OpenAIModel` with `OpenRouterProvider` and model name `google/gemini-2.5-pro-preview`. When you instantiate an Agent with just a name formatted as `<provider>:<model>`, e.g. `openai:gpt-4o` or `openrouter:google/gemini-2.5-pro-preview`, PydanticAI will automatically select the appropriate model class, provider, and profile. If you want to use a different provider or profile, you can instantiate a model class directly and pass in `provider` and/or `profile` arguments. ## Custom Models To implement support for a model API that's not already supported, you will need to subclass the Model abstract base class. For streaming, you'll also need to implement the StreamedResponse abstract base class. The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py). For details on when we'll accept contributions adding new models to PydanticAI, see the [contributing guidelines](../contributing/#new-model-rules). If a model API is compatible with the OpenAI API, you do not need a custom model class and can provide your own [custom provider](openai/#openai-compatible-models) instead. ## Fallback Model You can use FallbackModel to attempt multiple models in sequence until one successfully returns a result. Under the hood, PydanticAI automatically switches from one model to the next if the current model returns a 4xx or 5xx status code. In the following example, the agent first makes a request to the OpenAI model (which fails due to an invalid API key), and then falls back to the Anthropic model.
fallback_model.py ```python from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.openai import OpenAIModel openai_model = OpenAIModel('gpt-4o') anthropic_model = AnthropicModel('claude-3-5-sonnet-latest') fallback_model = FallbackModel(openai_model, anthropic_model) agent = Agent(fallback_model) response = agent.run_sync('What is the capital of France?') print(response.output) #> Paris print(response.all_messages()) """ [ ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), part_kind='user-prompt', ) ], kind='request', ), ModelResponse( parts=[TextPart(content='Paris', part_kind='text')], model_name='claude-3-5-sonnet-latest', timestamp=datetime.datetime(...), kind='response', vendor_id=None, ), ] """ ``` The `ModelResponse` message above indicates in the `model_name` field that the output was returned by the Anthropic model, which is the second model specified in the `FallbackModel`. Note Each model's options should be configured individually. For example, `base_url`, `api_key`, and custom clients should be set on each model itself, not on the `FallbackModel`. In this next example, we demonstrate the exception-handling capabilities of `FallbackModel`. If all models fail, a FallbackExceptionGroup is raised, which contains all the exceptions encountered during the `run` execution. fallback_model_failure.py ```python from pydantic_ai import Agent from pydantic_ai.exceptions import ModelHTTPError from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.openai import OpenAIModel openai_model = OpenAIModel('gpt-4o') anthropic_model = AnthropicModel('claude-3-5-sonnet-latest') fallback_model = FallbackModel(openai_model, anthropic_model) agent = Agent(fallback_model) try: response = agent.run_sync('What is the capital of France?') except* ModelHTTPError as exc_group: for exc in exc_group.exceptions: print(exc) ``` Since [`except*`](https://docs.python.org/3/reference/compound_stmts.html#except-star) is only supported in Python 3.11+, we use the [`exceptiongroup`](https://github.com/agronholm/exceptiongroup) backport package for earlier Python versions: fallback_model_failure.py ```python from exceptiongroup import catch from pydantic_ai import Agent from pydantic_ai.exceptions import ModelHTTPError from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.openai import OpenAIModel def model_status_error_handler(exc_group: BaseExceptionGroup) -> None: for exc in exc_group.exceptions: print(exc) openai_model = OpenAIModel('gpt-4o') anthropic_model = AnthropicModel('claude-3-5-sonnet-latest') fallback_model = FallbackModel(openai_model, anthropic_model) agent = Agent(fallback_model) with catch({ModelHTTPError: model_status_error_handler}): response = agent.run_sync('What is the capital of France?') ``` By default, the `FallbackModel` only moves on to the next model if the current model raises a ModelHTTPError. You can customize this behavior by passing a custom `fallback_on` argument to the `FallbackModel` constructor.
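For example, here's a minimal sketch of that customization. It assumes `fallback_on` accepts a predicate over the raised exception (it may also accept a tuple of exception types); check the `FallbackModel` API reference for the exact signature. fallback_model_custom.py

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest')


def should_fall_back(exc: Exception) -> bool:
    # Hypothetical policy: only fall back on server-side (5xx) errors,
    # since a 4xx usually means the request itself needs fixing.
    return isinstance(exc, ModelHTTPError) and exc.status_code >= 500


fallback_model = FallbackModel(
    openai_model,
    anthropic_model,
    fallback_on=should_fall_back,
)
agent = Agent(fallback_model)
response = agent.run_sync('What is the capital of France?')
```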
# Anthropic ## Install To use `AnthropicModel` models, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `anthropic` optional group: ```bash pip install "pydantic-ai-slim[anthropic]" ``` ```bash uv add "pydantic-ai-slim[anthropic]" ``` ## Configuration To use [Anthropic](https://anthropic.com) through their API, go to [console.anthropic.com/settings/keys](https://console.anthropic.com/settings/keys) to generate an API key. `AnthropicModelName` contains a list of available Anthropic models. ## Environment variable Once you have the API key, you can set it as an environment variable: ```bash export ANTHROPIC_API_KEY='your-api-key' ``` You can then use `AnthropicModel` by name: ```python from pydantic_ai import Agent agent = Agent('anthropic:claude-3-5-sonnet-latest') ... ``` Or initialise the model directly with just the model name: ```python from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel model = AnthropicModel('claude-3-5-sonnet-latest') agent = Agent(model) ... ``` ## `provider` argument You can provide a custom `Provider` via the `provider` argument: ```python from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.providers.anthropic import AnthropicProvider model = AnthropicModel( 'claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key='your-api-key') ) agent = Agent(model) ... ``` ## Custom HTTP Client You can customize the `AnthropicProvider` with a custom `httpx.AsyncClient`: ```python from httpx import AsyncClient from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.providers.anthropic import AnthropicProvider custom_http_client = AsyncClient(timeout=30) model = AnthropicModel( 'claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key='your-api-key', http_client=custom_http_client), ) agent = Agent(model) ... ``` # Bedrock ## Install To use `BedrockConverseModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `bedrock` optional group: ```bash pip install "pydantic-ai-slim[bedrock]" ``` ```bash uv add "pydantic-ai-slim[bedrock]" ``` ## Configuration To use [AWS Bedrock](https://aws.amazon.com/bedrock/), you'll need an AWS account with Bedrock enabled and appropriate credentials. You can use either AWS credentials directly or a pre-configured boto3 client. `BedrockModelName` contains a list of available Bedrock models, including models from Anthropic, Amazon, Cohere, Meta, and Mistral. ## Environment variables You can set your AWS credentials as environment variables ([among other options](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables)): ```bash export AWS_ACCESS_KEY_ID='your-access-key' export AWS_SECRET_ACCESS_KEY='your-secret-key' export AWS_DEFAULT_REGION='us-east-1' # or your preferred region ``` You can then use `BedrockConverseModel` by name: ```python from pydantic_ai import Agent agent = Agent('bedrock:anthropic.claude-3-sonnet-20240229-v1:0') ... ``` Or initialize the model directly with just the model name: ```python from pydantic_ai import Agent from pydantic_ai.models.bedrock import BedrockConverseModel model = BedrockConverseModel('anthropic.claude-3-sonnet-20240229-v1:0') agent = Agent(model) ... 
``` ## Customizing Bedrock Runtime API You can customize the Bedrock Runtime API calls by adding additional parameters, such as [guardrail configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html) and [performance settings](https://docs.aws.amazon.com/bedrock/latest/userguide/latency-optimized-inference.html). For a complete list of configurable parameters, refer to the documentation for BedrockModelSettings. customize_bedrock_model_settings.py ```python from pydantic_ai import Agent from pydantic_ai.models.bedrock import BedrockConverseModel, BedrockModelSettings # Define Bedrock model settings with guardrail and performance configurations bedrock_model_settings = BedrockModelSettings( bedrock_guardrail_config={ 'guardrailIdentifier': 'v1', 'guardrailVersion': 'v1', 'trace': 'enabled' }, bedrock_performance_configuration={ 'latency': 'optimized' } ) model = BedrockConverseModel(model_name='us.amazon.nova-pro-v1:0') agent = Agent(model=model, model_settings=bedrock_model_settings) ``` ## `provider` argument You can provide a custom `BedrockProvider` via the `provider` argument. This is useful when you want to specify credentials directly or use a custom boto3 client: ```python from pydantic_ai import Agent from pydantic_ai.models.bedrock import BedrockConverseModel from pydantic_ai.providers.bedrock import BedrockProvider # Using AWS credentials directly model = BedrockConverseModel( 'anthropic.claude-3-sonnet-20240229-v1:0', provider=BedrockProvider( region_name='us-east-1', aws_access_key_id='your-access-key', aws_secret_access_key='your-secret-key', ), ) agent = Agent(model) ... ``` You can also pass a pre-configured boto3 client: ```python import boto3 from pydantic_ai import Agent from pydantic_ai.models.bedrock import BedrockConverseModel from pydantic_ai.providers.bedrock import BedrockProvider # Using a pre-configured boto3 client bedrock_client = boto3.client('bedrock-runtime', region_name='us-east-1') model = BedrockConverseModel( 'anthropic.claude-3-sonnet-20240229-v1:0', provider=BedrockProvider(bedrock_client=bedrock_client), ) agent = Agent(model) ... ``` # Cohere ## Install To use `CohereModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `cohere` optional group: ```bash pip install "pydantic-ai-slim[cohere]" ``` ```bash uv add "pydantic-ai-slim[cohere]" ``` ## Configuration To use [Cohere](https://cohere.com/) through their API, go to [dashboard.cohere.com/api-keys](https://dashboard.cohere.com/api-keys) and follow your nose until you find the place to generate an API key. `CohereModelName` contains a list of the most popular Cohere models. ## Environment variable Once you have the API key, you can set it as an environment variable: ```bash export CO_API_KEY='your-api-key' ``` You can then use `CohereModel` by name: ```python from pydantic_ai import Agent agent = Agent('cohere:command') ... ``` Or initialise the model directly with just the model name: ```python from pydantic_ai import Agent from pydantic_ai.models.cohere import CohereModel model = CohereModel('command') agent = Agent(model) ... ``` ## `provider` argument You can provide a custom `Provider` via the `provider` argument: ```python from pydantic_ai import Agent from pydantic_ai.models.cohere import CohereModel from pydantic_ai.providers.cohere import CohereProvider model = CohereModel('command', provider=CohereProvider(api_key='your-api-key')) agent = Agent(model) ... 
``` You can also customize the `CohereProvider` with a custom `http_client`: ```python from httpx import AsyncClient from pydantic_ai import Agent from pydantic_ai.models.cohere import CohereModel from pydantic_ai.providers.cohere import CohereProvider custom_http_client = AsyncClient(timeout=30) model = CohereModel( 'command', provider=CohereProvider(api_key='your-api-key', http_client=custom_http_client), ) agent = Agent(model) ... ``` # Gemini Note We've developed a new Google model class called `GoogleModel`, which uses the `google-genai` package under the hood. Google's client packages have historically been messy, which is why `GeminiModel` uses plain `httpx` instead of relying on them. That said, the `google-genai` package is kept up to date with the latest API changes, so it's now easier to use `GoogleModel`. Check it out [here](../../api/models/google/). PydanticAI supports Google's Gemini models through two different APIs: - Generative Language API (`generativelanguage.googleapis.com`) - Vertex AI API (`*-aiplatform.googleapis.com`) ## Gemini via Generative Language API ### Install To use `GeminiModel` models, you just need to install `pydantic-ai` or `pydantic-ai-slim`; no extra dependencies are required. ### Configuration `GeminiModel` lets you use Google's Gemini models through their [Generative Language API](https://ai.google.dev/api/all-methods), `generativelanguage.googleapis.com`. `GeminiModelName` contains a list of available Gemini models that can be used through this interface. To use `GeminiModel`, go to [aistudio.google.com](https://aistudio.google.com/apikey) and select "Create API key". ### Environment variable Once you have the API key, you can set it as an environment variable: ```bash export GEMINI_API_KEY=your-api-key ``` You can then use `GeminiModel` by name: ```python from pydantic_ai import Agent agent = Agent('google-gla:gemini-2.0-flash') ... ``` Note The `google-gla` provider prefix represents the [Google **G**enerative **L**anguage **A**PI](https://ai.google.dev/api/all-methods) for `GeminiModel`s. `google-vertex` is used with [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models). Or initialise the model directly with just the model name and provider: ```python from pydantic_ai import Agent from pydantic_ai.models.gemini import GeminiModel model = GeminiModel('gemini-2.0-flash', provider='google-gla') agent = Agent(model) ... ``` ### `provider` argument You can provide a custom `Provider` via the `provider` argument: ```python from pydantic_ai import Agent from pydantic_ai.models.gemini import GeminiModel from pydantic_ai.providers.google_gla import GoogleGLAProvider model = GeminiModel( 'gemini-2.0-flash', provider=GoogleGLAProvider(api_key='your-api-key') ) agent = Agent(model) ... ``` You can also customize the `GoogleGLAProvider` with a custom `http_client`: ```python from httpx import AsyncClient from pydantic_ai import Agent from pydantic_ai.models.gemini import GeminiModel from pydantic_ai.providers.google_gla import GoogleGLAProvider custom_http_client = AsyncClient(timeout=30) model = GeminiModel( 'gemini-2.0-flash', provider=GoogleGLAProvider(api_key='your-api-key', http_client=custom_http_client), ) agent = Agent(model) ... ``` ## Gemini via VertexAI If you are an enterprise user, you should use the `google-vertex` provider with `GeminiModel` which uses the `*-aiplatform.googleapis.com` API.
`GeminiModelName` contains a list of available Gemini models that can be used through this interface.

### Install

To use the `google-vertex` provider with `GeminiModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `vertexai` optional group:

```bash
pip install "pydantic-ai-slim[vertexai]"
```

```bash
uv add "pydantic-ai-slim[vertexai]"
```

### Configuration

This interface has a number of advantages over `generativelanguage.googleapis.com` documented above:

1. The VertexAI API comes with more enterprise readiness guarantees.
1. You can [purchase provisioned throughput](https://cloud.google.com/vertex-ai/generative-ai/docs/provisioned-throughput#purchase-provisioned-throughput) with VertexAI to guarantee capacity.
1. If you're running PydanticAI inside GCP, you don't need to set up authentication; it should "just work".
1. You can decide which region to use, which might be important from a regulatory perspective, and might improve latency.

The big disadvantage is that for local development you may need to create and configure a "service account", which can be challenging to get right.

Whichever way you authenticate, you'll need to have VertexAI enabled in your GCP account.

### Application default credentials

Luckily, if you're running PydanticAI inside GCP, or you have the [`gcloud` CLI](https://cloud.google.com/sdk/gcloud) installed and configured, you should be able to use `GeminiModel` with the `google-vertex` provider without any additional setup.

With [application default credentials](https://cloud.google.com/docs/authentication/application-default-credentials) configured (e.g. with `gcloud`), you can simply use:

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel

model = GeminiModel('gemini-2.0-flash', provider='google-vertex')
agent = Agent(model)
...
```

Internally this uses [`google.auth.default()`](https://google-auth.readthedocs.io/en/master/reference/google.auth.html) from the `google-auth` package to obtain credentials.

Won't fail until `agent.run()`

Because `google.auth.default()` requires network requests and can be slow, it's not run until you call `agent.run()`.

You may also need to pass the `project_id` argument to `GoogleVertexProvider` if application default credentials don't set a project. If you pass `project_id` and it conflicts with the project set by application default credentials, an error is raised.

### Service account

If instead of application default credentials, you want to authenticate with a service account, you'll need to create a service account, add it to your GCP project (note: this step is necessary even if you created the service account within the project), give that service account the "Vertex AI Service Agent" role, and download the service account JSON file.

Once you have the JSON file, you can use it thus:

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_vertex import GoogleVertexProvider

model = GeminiModel(
    'gemini-2.0-flash',
    provider=GoogleVertexProvider(service_account_file='path/to/service-account.json'),
)
agent = Agent(model)
...
```

Alternatively, if you already have the service account information in memory, you can pass it as a dictionary:

```python
import json

from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_vertex import GoogleVertexProvider

service_account_info = json.loads(
    '{"type": "service_account", "project_id": "my-project-id"}'
)
model = GeminiModel(
    'gemini-2.0-flash',
    provider=GoogleVertexProvider(service_account_info=service_account_info),
)
agent = Agent(model)
...
```

### Customizing region

Whichever way you authenticate, you can specify which region requests will be sent to via the `region` argument.

Using a region close to your application can improve latency and might be important from a regulatory perspective.

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_vertex import GoogleVertexProvider

model = GeminiModel(
    'gemini-2.0-flash', provider=GoogleVertexProvider(region='asia-east1')
)
agent = Agent(model)
...
```

You can also customize the `GoogleVertexProvider` with a custom `http_client`:

```python
from httpx import AsyncClient

from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_vertex import GoogleVertexProvider

custom_http_client = AsyncClient(timeout=30)
model = GeminiModel(
    'gemini-2.0-flash',
    provider=GoogleVertexProvider(region='asia-east1', http_client=custom_http_client),
)
agent = Agent(model)
...
```

### Model settings

You can use the GeminiModelSettings class to customize the model request.

#### Disable thinking

You can disable thinking by setting the `thinking_budget` to `0` on the `gemini_thinking_config`:

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings

model_settings = GeminiModelSettings(gemini_thinking_config={'thinking_budget': 0})
model = GeminiModel('gemini-2.0-flash')
agent = Agent(model, model_settings=model_settings)
...
```

Check out the [Gemini API docs](https://ai.google.dev/gemini-api/docs/thinking) for more on thinking.

#### Safety settings

You can customize the safety settings by setting the `gemini_safety_settings` field.

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings

model_settings = GeminiModelSettings(
    gemini_safety_settings=[
        {
            'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
            'threshold': 'BLOCK_ONLY_HIGH',
        }
    ]
)
model = GeminiModel('gemini-2.0-flash')
agent = Agent(model, model_settings=model_settings)
...
```

Check out the [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for more on safety settings.

# Google

`GoogleModel` uses the [`google-genai`](https://pypi.org/project/google-genai/) package under the hood to access Google's Gemini models via both the Generative Language API and Vertex AI.

## Install

To use `GoogleModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `google` optional group:

```bash
pip install "pydantic-ai-slim[google]"
```

```bash
uv add "pydantic-ai-slim[google]"
```

______________________________________________________________________

Explicit instantiation required

You **cannot** currently use `Agent('google-gla:gemini-1.5-flash')` or `Agent('google-vertex:gemini-1.5-flash')` directly with `GoogleModel`. The model name inference will select [`GeminiModel`](../gemini/) instead of `GoogleModel`.
To use `GoogleModel`, you **must** explicitly instantiate a GoogleProvider and pass it to GoogleModel, then pass the model to Agent. ______________________________________________________________________ ## Configuration `GoogleModel` lets you use Google's Gemini models through their [Generative Language API](https://ai.google.dev/api/all-methods) (`generativelanguage.googleapis.com`) or [Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) (`*-aiplatform.googleapis.com`). ### API Key (Generative Language API) To use Gemini via the Generative Language API, go to [aistudio.google.com](https://aistudio.google.com/apikey) and create an API key. Once you have the API key, set it as an environment variable: ```bash export GOOGLE_API_KEY=your-api-key ``` You can then use `GoogleModel` by explicitly creating a provider: ```python from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider provider = GoogleProvider(api_key='your-api-key') model = GoogleModel('gemini-1.5-flash', provider=provider) agent = Agent(model) ... ``` ### Vertex AI (Enterprise/Cloud) If you are an enterprise user, you can use the `google-vertex` provider with `GoogleModel` to access Gemini via Vertex AI. To use Vertex AI, you may need to set up [application default credentials](https://cloud.google.com/docs/authentication/application-default-credentials) or use a service account. You can also specify the region. #### Application Default Credentials If you have the [`gcloud` CLI](https://cloud.google.com/sdk/gcloud) installed and configured, you can use: ```python from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider provider = GoogleProvider(vertexai=True) model = GoogleModel('gemini-1.5-flash', provider=provider) agent = Agent(model) ... ``` #### Service Account To use a service account JSON file: google_model_service_account.py ```python from google.oauth2 import service_account from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider credentials = service_account.Credentials.from_service_account_file( 'path/to/service-account.json', scopes=['https://www.googleapis.com/auth/cloud-platform'], ) provider = GoogleProvider(credentials=credentials) model = GoogleModel('gemini-1.5-flash', provider=provider) agent = Agent(model) ... ``` #### Customizing Location You can specify the location when using Vertex AI: google_model_location.py ```python from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider provider = GoogleProvider(vertexai=True, location='asia-east1') model = GoogleModel('gemini-1.5-flash', provider=provider) agent = Agent(model) ... ``` ## Provider Argument You can supply a custom `GoogleProvider` instance using the `provider` argument to configure advanced client options, such as setting a custom `base_url`. This is useful if you're using a custom-compatible endpoint with the Google Generative Language API. 
```python from google import genai from google.genai.types import HttpOptions from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider client = genai.Client( api_key='gemini-custom-api-key', http_options=HttpOptions(base_url='gemini-custom-base-url'), ) provider = GoogleProvider(client=client) model = GoogleModel('gemini-1.5-flash', provider=provider) agent = Agent(model) ... ``` ## Model Settings You can customize model behavior using GoogleModelSettings: ```python from google.genai.types import HarmBlockThreshold, HarmCategory from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel, GoogleModelSettings settings = GoogleModelSettings( temperature=0.2, max_tokens=1024, google_thinking_config={'thinking_budget': 2048}, google_safety_settings=[ { 'category': HarmCategory.HARM_CATEGORY_HATE_SPEECH, 'threshold': HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, } ] ) model = GoogleModel('gemini-1.5-flash') agent = Agent(model, model_settings=settings) ... ``` See the [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for more on safety settings, and [thinking config](https://ai.google.dev/gemini-api/docs/thinking). ## Document, Image, Audio, and Video Input `GoogleModel` supports multi-modal input, including documents, images, audio, and video. See the [input documentation](../../input/) for details and examples. ## Model settings You can use the GoogleModelSettings class to customize the model request. ### Disable thinking You can disable thinking by setting the `thinking_budget` to `0` on the `google_thinking_config`: ```python from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel, GoogleModelSettings model_settings = GoogleModelSettings(google_thinking_config={'thinking_budget': 0}) model = GoogleModel('gemini-2.0-flash') agent = Agent(model, model_settings=model_settings) ... ``` Check out the [Gemini API docs](https://ai.google.dev/gemini-api/docs/thinking) for more on thinking. ### Safety settings You can customize the safety settings by setting the `google_safety_settings` field. ```python from google.genai.types import HarmBlockThreshold, HarmCategory from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel, GoogleModelSettings model_settings = GoogleModelSettings( google_safety_settings=[ { 'category': HarmCategory.HARM_CATEGORY_HATE_SPEECH, 'threshold': HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, } ] ) model = GoogleModel('gemini-2.0-flash') agent = Agent(model, model_settings=model_settings) ... ``` See the [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for more on safety settings. # Groq ## Install To use `GroqModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `groq` optional group: ```bash pip install "pydantic-ai-slim[groq]" ``` ```bash uv add "pydantic-ai-slim[groq]" ``` ## Configuration To use [Groq](https://groq.com/) through their API, go to [console.groq.com/keys](https://console.groq.com/keys) and follow your nose until you find the place to generate an API key. `GroqModelName` contains a list of available Groq models. ## Environment variable Once you have the API key, you can set it as an environment variable: ```bash export GROQ_API_KEY='your-api-key' ``` You can then use `GroqModel` by name: ```python from pydantic_ai import Agent agent = Agent('groq:llama-3.3-70b-versatile') ... 
```

Or initialise the model directly with just the model name:

```python
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel

model = GroqModel('llama-3.3-70b-versatile')
agent = Agent(model)
...
```

## `provider` argument

You can provide a custom `Provider` via the `provider` argument:

```python
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.providers.groq import GroqProvider

model = GroqModel(
    'llama-3.3-70b-versatile', provider=GroqProvider(api_key='your-api-key')
)
agent = Agent(model)
...
```

You can also customize the `GroqProvider` with a custom `httpx.AsyncClient`:

```python
from httpx import AsyncClient

from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.providers.groq import GroqProvider

custom_http_client = AsyncClient(timeout=30)
model = GroqModel(
    'llama-3.3-70b-versatile',
    provider=GroqProvider(api_key='your-api-key', http_client=custom_http_client),
)
agent = Agent(model)
...
```

# Mistral

## Install

To use `MistralModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `mistral` optional group:

```bash
pip install "pydantic-ai-slim[mistral]"
```

```bash
uv add "pydantic-ai-slim[mistral]"
```

## Configuration

To use [Mistral](https://mistral.ai) through their API, go to [console.mistral.ai/api-keys/](https://console.mistral.ai/api-keys/) and follow your nose until you find the place to generate an API key.

`LatestMistralModelNames` contains a list of the most popular Mistral models.

## Environment variable

Once you have the API key, you can set it as an environment variable:

```bash
export MISTRAL_API_KEY='your-api-key'
```

You can then use `MistralModel` by name:

```python
from pydantic_ai import Agent

agent = Agent('mistral:mistral-large-latest')
...
```

Or initialise the model directly with just the model name:

```python
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

model = MistralModel('mistral-small-latest')
agent = Agent(model)
...
```

## `provider` argument

You can provide a custom `Provider` via the `provider` argument:

```python
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel
from pydantic_ai.providers.mistral import MistralProvider

model = MistralModel(
    'mistral-large-latest',
    provider=MistralProvider(api_key='your-api-key', base_url='https://')
)
agent = Agent(model)
...
```

You can also customize the provider with a custom `httpx.AsyncClient`:

```python
from httpx import AsyncClient

from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel
from pydantic_ai.providers.mistral import MistralProvider

custom_http_client = AsyncClient(timeout=30)
model = MistralModel(
    'mistral-large-latest',
    provider=MistralProvider(api_key='your-api-key', http_client=custom_http_client),
)
agent = Agent(model)
...
```

# OpenAI

## Install

To use OpenAI models or OpenAI-compatible APIs, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `openai` optional group:

```bash
pip install "pydantic-ai-slim[openai]"
```

```bash
uv add "pydantic-ai-slim[openai]"
```

## Configuration

To use `OpenAIModel` with the OpenAI API, go to [platform.openai.com](https://platform.openai.com/) and follow your nose until you find the place to generate an API key.
## Environment variable Once you have the API key, you can set it as an environment variable: ```bash export OPENAI_API_KEY='your-api-key' ``` You can then use `OpenAIModel` by name: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') ... ``` Or initialise the model directly with just the model name: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel model = OpenAIModel('gpt-4o') agent = Agent(model) ... ``` By default, the `OpenAIModel` uses the `OpenAIProvider` with the `base_url` set to `https://api.openai.com/v1`. ## Configure the provider If you want to pass parameters in code to the provider, you can programmatically instantiate the OpenAIProvider and pass it to the model: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key='your-api-key')) agent = Agent(model) ... ``` ## Custom OpenAI Client `OpenAIProvider` also accepts a custom `AsyncOpenAI` client via the `openai_client` parameter, so you can customise the `organization`, `project`, `base_url` etc. as defined in the [OpenAI API docs](https://platform.openai.com/docs/api-reference). custom_openai_client.py ```python from openai import AsyncOpenAI from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider client = AsyncOpenAI(max_retries=3) model = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=client)) agent = Agent(model) ... ``` You could also use the [`AsyncAzureOpenAI`](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints) client to use the Azure OpenAI API. Note that the `AsyncAzureOpenAI` is a subclass of `AsyncOpenAI`. ```python from openai import AsyncAzureOpenAI from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider client = AsyncAzureOpenAI( azure_endpoint='...', api_version='2024-07-01-preview', api_key='your-api-key', ) model = OpenAIModel( 'gpt-4o', provider=OpenAIProvider(openai_client=client), ) agent = Agent(model) ... ``` ## OpenAI Responses API PydanticAI also supports OpenAI's [Responses API](https://platform.openai.com/docs/api-reference/responses) through the `OpenAIResponsesModel` class. ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIResponsesModel model = OpenAIResponsesModel('gpt-4o') agent = Agent(model) ... ``` The Responses API has built-in tools that you can use instead of building your own: - [Web search](https://platform.openai.com/docs/guides/tools-web-search): allow models to search the web for the latest information before generating a response. - [File search](https://platform.openai.com/docs/guides/tools-file-search): allow models to search your files for relevant information before generating a response. - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use): allow models to use a computer to perform tasks on your behalf. You can use the `OpenAIResponsesModelSettings` class to make use of those built-in tools: ```python from openai.types.responses import WebSearchToolParam # (1)! 
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

model_settings = OpenAIResponsesModelSettings(
    openai_builtin_tools=[WebSearchToolParam(type='web_search_preview')],
)
model = OpenAIResponsesModel('gpt-4o')
agent = Agent(model=model, model_settings=model_settings)

result = agent.run_sync('What is the weather in Tokyo?')
print(result.output)
"""
As of 7:48 AM on Wednesday, April 2, 2025, in Tokyo, Japan, the weather is cloudy with a temperature of 53°F (12°C).
"""
```

1. The file search tool and computer use tool can also be imported from `openai.types.responses`.

You can learn more about the differences between the Responses API and Chat Completions API in the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).

## OpenAI-compatible Models

Many providers and models are compatible with the OpenAI API, and can be used with `OpenAIModel` in PydanticAI. Before getting started, check the [installation and configuration](#install) instructions above.

To use another OpenAI-compatible API, you can make use of the `base_url` and `api_key` arguments from `OpenAIProvider`:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

model = OpenAIModel(
    'model_name',
    provider=OpenAIProvider(
        base_url='https://.com', api_key='your-api-key'
    ),
)
agent = Agent(model)
...
```

Various providers also have their own provider classes so that you don't need to specify the base URL yourself and you can use the standard `<PROVIDER>_API_KEY` environment variable to set the API key.

When a provider has its own provider class, you can use the `Agent("<provider>:<model>")` shorthand, e.g. `Agent("deepseek:deepseek-chat")` or `Agent("openrouter:google/gemini-2.5-pro-preview")`, instead of building the `OpenAIModel` explicitly. Similarly, you can pass the provider name as a string to the `provider` argument on `OpenAIModel` instead of instantiating the provider class explicitly.

#### Model Profile

Sometimes, the provider or model you're using will have slightly different requirements than OpenAI's API or models, like having different restrictions on JSON schemas for tool definitions, or not supporting tool definitions to be marked as strict.

When using an alternative provider class provided by PydanticAI, an appropriate model profile is typically selected automatically based on the model name.
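For example, here's a minimal sketch of the string-`provider` shorthand mentioned above, assuming the DeepSeek provider covered below; the matching model profile is selected for you based on the model name, and the API key is read from the `DEEPSEEK_API_KEY` environment variable following the pattern above:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# Pass the provider by name rather than instantiating a provider class;
# the API key is read from the DEEPSEEK_API_KEY environment variable.
model = OpenAIModel('deepseek-chat', provider='deepseek')
agent = Agent(model)
...
```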
If the model you're using is not working correctly out of the box, you can tweak various aspects of how model requests are constructed by providing your own ModelProfile (for behaviors shared among all model classes) or OpenAIModelProfile (for behaviors specific to `OpenAIModel`): ```py from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile from pydantic_ai.providers.openai import OpenAIProvider model = OpenAIModel( 'model_name', provider=OpenAIProvider( base_url='https://.com', api_key='your-api-key' ), profile=OpenAIModelProfile( json_schema_transformer=InlineDefsJsonSchemaTransformer, # Supported by any model class on a plain ModelProfile openai_supports_strict_tool_definition=False # Supported by OpenAIModel only, requires OpenAIModelProfile ) ) agent = Agent(model) ``` ### DeepSeek To use the [DeepSeek](https://deepseek.com) provider, first create an API key by following the [Quick Start guide](https://api-docs.deepseek.com/). Once you have the API key, you can use it with the DeepSeekProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.deepseek import DeepSeekProvider model = OpenAIModel( 'deepseek-chat', provider=DeepSeekProvider(api_key='your-deepseek-api-key'), ) agent = Agent(model) ... ``` You can also customize any provider with a custom `http_client`: ```python from httpx import AsyncClient from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.deepseek import DeepSeekProvider custom_http_client = AsyncClient(timeout=30) model = OpenAIModel( 'deepseek-chat', provider=DeepSeekProvider( api_key='your-deepseek-api-key', http_client=custom_http_client ), ) agent = Agent(model) ... ``` ### Ollama To use [Ollama](https://ollama.com/), you must first download the Ollama client, and then download a model using the [Ollama model library](https://ollama.com/library). You must also ensure the Ollama server is running when trying to make requests to it. For more information, please see the [Ollama documentation](https://github.com/ollama/ollama/tree/main/docs). #### Example local usage With `ollama` installed, you can run the server with the model you want to use: ```bash ollama run llama3.2 ``` (this will pull the `llama3.2` model if you don't already have it downloaded) Then run your code, here's a minimal example: ```python from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider class CityLocation(BaseModel): city: str country: str ollama_model = OpenAIModel( model_name='llama3.2', provider=OpenAIProvider(base_url='http://localhost:11434/v1') ) agent = Agent(ollama_model, output_type=CityLocation) result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) #> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) ``` #### Example using a remote server ```python from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider ollama_model = OpenAIModel( model_name='qwen2.5-coder:7b', # (1)! provider=OpenAIProvider(base_url='http://192.168.1.74:11434/v1'), # (2)! 
) class CityLocation(BaseModel): city: str country: str agent = Agent(model=ollama_model, output_type=CityLocation) result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) #> Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65) ``` 1. The name of the model running on the remote server 1. The url of the remote server ### Azure AI Foundry If you want to use [Azure AI Foundry](https://ai.azure.com/) as your provider, you can do so by using the AzureProvider class. ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.azure import AzureProvider model = OpenAIModel( 'gpt-4o', provider=AzureProvider( azure_endpoint='your-azure-endpoint', api_version='your-api-version', api_key='your-api-key', ), ) agent = Agent(model) ... ``` ### OpenRouter To use [OpenRouter](https://openrouter.ai), first create an API key at [openrouter.ai/keys](https://openrouter.ai/keys). Once you have the API key, you can use it with the OpenRouterProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openrouter import OpenRouterProvider model = OpenAIModel( 'anthropic/claude-3.5-sonnet', provider=OpenRouterProvider(api_key='your-openrouter-api-key'), ) agent = Agent(model) ... ``` ### Grok (xAI) Go to [xAI API Console](https://console.x.ai/) and create an API key. Once you have the API key, you can use it with the GrokProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.grok import GrokProvider model = OpenAIModel( 'grok-2-1212', provider=GrokProvider(api_key='your-xai-api-key'), ) agent = Agent(model) ... ``` ### GitHub Models To use [GitHub Models](https://docs.github.com/en/github-models), you'll need a GitHub personal access token with the `models: read` permission. Once you have the token, you can use it with the GitHubProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.github import GitHubProvider model = OpenAIModel( 'xai/grok-3-mini', # GitHub Models uses prefixed model names provider=GitHubProvider(api_key='your-github-token'), ) agent = Agent(model) ... ``` You can also set the `GITHUB_API_KEY` environment variable: ```bash export GITHUB_API_KEY='your-github-token' ``` GitHub Models supports various model families with different prefixes. You can see the full list on the [GitHub Marketplace](https://github.com/marketplace?type=models) or the public [catalog endpoint](https://models.github.ai/catalog/models). ### Perplexity Follow the Perplexity [getting started](https://docs.perplexity.ai/guides/getting-started) guide to create an API key. Then, you can query the Perplexity API with the following: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider model = OpenAIModel( 'sonar-pro', provider=OpenAIProvider( base_url='https://api.perplexity.ai', api_key='your-perplexity-api-key', ), ) agent = Agent(model) ... ``` ### Fireworks AI Go to [Fireworks.AI](https://fireworks.ai/) and create an API key in your account settings. 
Once you have the API key, you can use it with the FireworksProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.fireworks import FireworksProvider model = OpenAIModel( 'accounts/fireworks/models/qwq-32b', # model library available at https://fireworks.ai/models provider=FireworksProvider(api_key='your-fireworks-api-key'), ) agent = Agent(model) ... ``` ### Together AI Go to [Together.ai](https://www.together.ai/) and create an API key in your account settings. Once you have the API key, you can use it with the TogetherProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.together import TogetherProvider model = OpenAIModel( 'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free', # model library available at https://www.together.ai/models provider=TogetherProvider(api_key='your-together-api-key'), ) agent = Agent(model) ... ``` ### Heroku AI To use [Heroku AI](https://www.heroku.com/ai), you can use the HerokuProvider: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.heroku import HerokuProvider model = OpenAIModel( 'claude-3-7-sonnet', provider=HerokuProvider(api_key='your-heroku-inference-key'), ) agent = Agent(model) ... ``` You can set the `HEROKU_INFERENCE_KEY` and `HEROKU_INFERENCE_URL` environment variables to set the API key and base URL, respectively: ```bash export HEROKU_INFERENCE_KEY='your-heroku-inference-key' export HEROKU_INFERENCE_URL='https://us.inference.heroku.com' ``` # Graphs # Graphs Don't use a nail gun unless you need a nail gun If PydanticAI [agents](../agents/) are a hammer, and [multi-agent workflows](../multi-agent-applications/) are a sledgehammer, then graphs are a nail gun: - sure, nail guns look cooler than hammers - but nail guns take a lot more setup than hammers - and nail guns don't make you a better builder, they make you a builder with a nail gun - Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or our approach to graphs. (But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer when you realize you need it) In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](../multi-agent-applications/) before proceeding. If you're not confident a graph-based approach is a good idea, it might be unnecessary. Graphs and finite state machines (FSMs) are a powerful abstraction to model, execute, control and visualize complex workflows. Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. While this library is developed as part of PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. `pydantic-graph` is designed for advanced users and makes heavy use of Python generics and type hints. It is not designed to be as beginner-friendly as PydanticAI. 
## Installation `pydantic-graph` is a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](../install/#slim-install) for more information. You can also install it directly: ```bash pip install pydantic-graph ``` ```bash uv add pydantic-graph ``` ## Graph Types `pydantic-graph` is made up of a few key components: ### GraphRunContext GraphRunContext — The context for the graph run, similar to PydanticAI's RunContext. This holds the state of the graph and dependencies and is passed to nodes when they're run. `GraphRunContext` is generic in the state type of the graph it's used in, StateT. ### End End — return value to indicate the graph run should end. `End` is generic in the graph return type of the graph it's used in, RunEndT. ### Nodes Subclasses of BaseNode define nodes for execution in the graph. Nodes, which are generally dataclasses, generally consist of: - fields containing any parameters required/optional when calling the node - the business logic to execute the node, in the run method - return annotations of the run method, which are read by `pydantic-graph` to determine the outgoing edges of the node Nodes are generic in: - **state**, which must have the same type as the state of graphs they're included in, StateT has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information - **deps**, which must have the same type as the deps of the graph they're included in, DepsT has a default of `None`, so if you're not using deps you can omit this generic parameter, see [dependency injection](#dependency-injection) for more information - **graph return type** — this only applies if the node returns End. RunEndT has a default of Never so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return End: intermediate_node.py ```py from dataclasses import dataclass from pydantic_graph import BaseNode, GraphRunContext @dataclass class MyNode(BaseNode[MyState]): # (1)! foo: int # (2)! async def run( self, ctx: GraphRunContext[MyState], # (3)! ) -> AnotherNode: # (4)! ... return AnotherNode() ``` 1. State in this example is `MyState` (not shown), hence `BaseNode` is parameterized with `MyState`. This node can't end the run, so the `RunEndT` generic parameter is omitted and defaults to `Never`. 1. `MyNode` is a dataclass and has a single field `foo`, an `int`. 1. The `run` method takes a `GraphRunContext` parameter, again parameterized with state `MyState`. 1. The return type of the `run` method is `AnotherNode` (not shown), this is used to determine the outgoing edges of the node. We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: intermediate_or_end_node.py ```py from dataclasses import dataclass from pydantic_graph import BaseNode, End, GraphRunContext @dataclass class MyNode(BaseNode[MyState, None, int]): # (1)! foo: int async def run( self, ctx: GraphRunContext[MyState], ) -> AnotherNode | End[int]: # (2)! if self.foo % 5 == 0: return End(self.foo) else: return AnotherNode() ``` 1. We parameterize the node with the return type (`int` in this case) as well as state. Because generic parameters are positional-only, we have to include `None` as the second parameter representing deps. 1. 
The return type of the `run` method is now a union of `AnotherNode` and `End[int]`, this allows the node to end the run if `foo` is divisible by 5. ### Graph Graph — this is the execution graph itself, made up of a set of [node classes](#nodes) (i.e., `BaseNode` subclasses). `Graph` is generic in: - **state** the state type of the graph, StateT - **deps** the deps type of the graph, DepsT - **graph return type** the return type of the graph run, RunEndT Here's an example of a simple graph: graph_example.py ```py from __future__ import annotations from dataclasses import dataclass from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass class DivisibleBy5(BaseNode[None, None, int]): # (1)! foo: int async def run( self, ctx: GraphRunContext, ) -> Increment | End[int]: if self.foo % 5 == 0: return End(self.foo) else: return Increment(self.foo) @dataclass class Increment(BaseNode): # (2)! foo: int async def run(self, ctx: GraphRunContext) -> DivisibleBy5: return DivisibleBy5(self.foo + 1) fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! result = fives_graph.run_sync(DivisibleBy5(4)) # (4)! print(result.output) #> 5 ``` 1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. 1. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. 1. The graph is created with a sequence of nodes. 1. The graph is run synchronously with run_sync. The initial node is `DivisibleBy5(4)`. Because the graph doesn't use external state or deps, we don't pass `state` or `deps`. *(This example is complete, it can be run "as is" with Python 3.10+)* A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: graph_example_diagram.py ```py from graph_example import DivisibleBy5, fives_graph fives_graph.mermaid_code(start_node=DivisibleBy5) ``` ``` --- title: fives_graph --- stateDiagram-v2 [*] --> DivisibleBy5 DivisibleBy5 --> Increment DivisibleBy5 --> [*] Increment --> DivisibleBy5 ``` In order to visualize a graph within a `jupyter-notebook`, `IPython.display` needs to be used: jupyter_display_mermaid.py ```python from graph_example import DivisibleBy5, fives_graph from IPython.display import Image, display display(Image(fives_graph.mermaid_image(start_node=DivisibleBy5))) ``` ## Stateful Graphs The "state" concept in `pydantic-graph` provides an optional way to access and mutate an object (often a `dataclass` or Pydantic model) as nodes run in a graph. If you think of Graphs as a production line, then your state is the engine being passed along the line and built up by each node as the graph is run. `pydantic-graph` provides state persistence, with the state recorded after each node is run. (See [State Persistence](#state-persistence).) Here's an example of a graph which represents a vending machine where the user may insert coins and select a product to purchase. vending_machine.py ```python from __future__ import annotations from dataclasses import dataclass from rich.prompt import Prompt from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass class MachineState: # (1)! user_balance: float = 0.0 product: str | None = None @dataclass class InsertCoin(BaseNode[MachineState]): # (3)! async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted: # (16)! 
return CoinsInserted(float(Prompt.ask('Insert coins'))) # (4)! @dataclass class CoinsInserted(BaseNode[MachineState]): amount: float # (5)! async def run( self, ctx: GraphRunContext[MachineState] ) -> SelectProduct | Purchase: # (17)! ctx.state.user_balance += self.amount # (6)! if ctx.state.product is not None: # (7)! return Purchase(ctx.state.product) else: return SelectProduct() @dataclass class SelectProduct(BaseNode[MachineState]): async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase: return Purchase(Prompt.ask('Select product')) PRODUCT_PRICES = { # (2)! 'water': 1.25, 'soda': 1.50, 'crisps': 1.75, 'chocolate': 2.00, } @dataclass class Purchase(BaseNode[MachineState, None, None]): # (18)! product: str async def run( self, ctx: GraphRunContext[MachineState] ) -> End | InsertCoin | SelectProduct: if price := PRODUCT_PRICES.get(self.product): # (8)! ctx.state.product = self.product # (9)! if ctx.state.user_balance >= price: # (10)! ctx.state.user_balance -= price return End(None) else: diff = price - ctx.state.user_balance print(f'Not enough money for {self.product}, need {diff:0.2f} more') #> Not enough money for crisps, need 0.75 more return InsertCoin() # (11)! else: print(f'No such product: {self.product}, try again') return SelectProduct() # (12)! vending_machine_graph = Graph( # (13)! nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase] ) async def main(): state = MachineState() # (14)! await vending_machine_graph.run(InsertCoin(), state=state) # (15)! print(f'purchase successful item={state.product} change={state.user_balance:0.2f}') #> purchase successful item=crisps change=0.25 ``` 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 1. A dictionary of products mapped to prices. 1. The `InsertCoin` node, BaseNode is parameterized with `MachineState` as that's the state used in this graph. 1. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using rich's Prompt.ask within nodes, see [below](#example-human-in-the-loop) for how control flow can be managed when nodes require external input. 1. The `CoinsInserted` node; again this is a dataclass with one field `amount`. 1. Update the user's balance with the amount inserted. 1. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. 1. In the `Purchase` node, look up the price of the product if the user entered a valid product. 1. If the user did enter a valid product, set the product in the state so we don't revisit `SelectProduct`. 1. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return End to end the graph. We're not using the run return type, so we call `End` with `None`. 1. If the balance is insufficient, go to `InsertCoin` to prompt the user to insert more coins. 1. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. 1. The graph is created by passing a list of nodes to Graph. Order of nodes is not important, but it can affect how [diagrams](#mermaid-diagrams) are displayed. 1. Initialize the state. This will be passed to the graph run and mutated as the graph runs. 1. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. 
Graph.run returns a GraphRunResult that provides the final data and a history of the run. 1. The return type of the node's run method is important as it is used to determine the outgoing edges of the node. This information in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime to detect misbehavior as soon as possible. 1. The return type of `CoinsInserted`'s run method is a union, meaning multiple outgoing edges are possible. 1. Unlike other nodes, `Purchase` can end the run, so the RunEndT generic parameter must be set. In this case it's `None` since the graph run return type is `None`. *(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)* A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: vending_machine_diagram.py ```py from vending_machine import InsertCoin, vending_machine_graph vending_machine_graph.mermaid_code(start_node=InsertCoin) ``` The diagram generated by the above code is: ``` --- title: vending_machine_graph --- stateDiagram-v2 [*] --> InsertCoin InsertCoin --> CoinsInserted CoinsInserted --> SelectProduct CoinsInserted --> Purchase SelectProduct --> Purchase Purchase --> InsertCoin Purchase --> SelectProduct Purchase --> [*] ``` See [below](#mermaid-diagrams) for more information on generating diagrams. ## GenAI Example So far we haven't shown an example of a Graph that actually uses PydanticAI or GenAI at all. In this example, one agent generates a welcome email to a user and the other agent provides feedback on the email. This graph has a very simple structure: ``` --- title: feedback_graph --- stateDiagram-v2 [*] --> WriteEmail WriteEmail --> Feedback Feedback --> WriteEmail Feedback --> [*] ``` genai_email_feedback.py ```python from __future__ import annotations as _annotations from dataclasses import dataclass, field from pydantic import BaseModel, EmailStr from pydantic_ai import Agent, format_as_xml from pydantic_ai.messages import ModelMessage from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass class User: name: str email: EmailStr interests: list[str] @dataclass class Email: subject: str body: str @dataclass class State: user: User write_agent_messages: list[ModelMessage] = field(default_factory=list) email_writer_agent = Agent( 'google-vertex:gemini-1.5-pro', output_type=Email, system_prompt='Write a welcome email to our tech blog.', ) @dataclass class WriteEmail(BaseNode[State]): email_feedback: str | None = None async def run(self, ctx: GraphRunContext[State]) -> Feedback: if self.email_feedback: prompt = ( f'Rewrite the email for the user:\n' f'{format_as_xml(ctx.state.user)}\n' f'Feedback: {self.email_feedback}' ) else: prompt = ( f'Write a welcome email for the user:\n' f'{format_as_xml(ctx.state.user)}' ) result = await email_writer_agent.run( prompt, message_history=ctx.state.write_agent_messages, ) ctx.state.write_agent_messages += result.new_messages() return Feedback(result.output) class EmailRequiresWrite(BaseModel): feedback: str class EmailOk(BaseModel): pass feedback_agent = Agent[None, EmailRequiresWrite | EmailOk]( 'openai:gpt-4o', output_type=EmailRequiresWrite | EmailOk, # type: ignore system_prompt=( 'Review the email and provide feedback, email must reference the users specific interests.' 
), ) @dataclass class Feedback(BaseNode[State, None, Email]): email: Email async def run( self, ctx: GraphRunContext[State], ) -> WriteEmail | End[Email]: prompt = format_as_xml({'user': ctx.state.user, 'email': self.email}) result = await feedback_agent.run(prompt) if isinstance(result.output, EmailRequiresWrite): return WriteEmail(email_feedback=result.output.feedback) else: return End(self.email) async def main(): user = User( name='John Doe', email='john.joe@example.com', interests=['Haskel', 'Lisp', 'Fortran'], ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) result = await feedback_graph.run(WriteEmail(), state=state) print(result.output) """ Email( subject='Welcome to our tech blog!', body='Hello John, Welcome to our tech blog! ...', ) """ ``` *(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)* ## Iterating Over a Graph ### Using `Graph.iter` for `async for` iteration Sometimes you want direct control or insight into each node as the graph executes. The easiest way to do that is with the Graph.iter method, which returns a **context manager** that yields a GraphRun object. The `GraphRun` is an async-iterable over the nodes of your graph, allowing you to record or modify them as they execute. Here's an example: count_down.py ```python from __future__ import annotations as _annotations from dataclasses import dataclass from pydantic_graph import Graph, BaseNode, End, GraphRunContext @dataclass class CountDownState: counter: int @dataclass class CountDown(BaseNode[CountDownState, None, int]): async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: if ctx.state.counter <= 0: return End(ctx.state.counter) ctx.state.counter -= 1 return CountDown() count_down_graph = Graph(nodes=[CountDown]) async def main(): state = CountDownState(counter=3) async with count_down_graph.iter(CountDown(), state=state) as run: # (1)! async for node in run: # (2)! print('Node:', node) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: End(data=0) print('Final output:', run.result.output) # (3)! #> Final output: 0 ``` 1. `Graph.iter(...)` returns a GraphRun. 1. Here, we step through each node as it is executed. 1. Once the graph returns an End, the loop ends, and `run.result` becomes a GraphRunResult containing the final outcome (`0` here). ### Using `GraphRun.next(node)` manually Alternatively, you can drive iteration manually with the GraphRun.next method, which allows you to pass in whichever node you want to run next. You can modify or selectively skip nodes this way. Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: count_down_next.py ```python from pydantic_graph import End, FullStatePersistence from count_down import CountDown, CountDownState, count_down_graph async def main(): state = CountDownState(counter=5) persistence = FullStatePersistence() # (7)! async with count_down_graph.iter( CountDown(), state=state, persistence=persistence ) as run: node = run.next_node # (1)! while not isinstance(node, End): # (2)! print('Node:', node) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() if state.counter == 2: break # (3)! node = await run.next(node) # (4)! print(run.result) # (5)! #> None for step in persistence.history: # (6)! 
            print('History Step:', step.state, step.state)
            #> History Step: CountDownState(counter=5) CountDownState(counter=5)
            #> History Step: CountDownState(counter=4) CountDownState(counter=4)
            #> History Step: CountDownState(counter=3) CountDownState(counter=3)
            #> History Step: CountDownState(counter=2) CountDownState(counter=2)
```

1. We start by grabbing the first node that will be run in the agent's graph.
1. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`.
1. If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.result` remains `None`).
1. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`).
1. Because we did not continue the run until it finished, the `result` is not set.
1. The run's history is still populated with the steps we executed so far.
1. Use FullStatePersistence so we can show the history of the run; see [State Persistence](#state-persistence) below for more information.

## State Persistence

One of the biggest benefits of finite state machine (FSM) graphs is how they simplify the handling of interrupted execution. This might happen for a variety of reasons:

- the state machine logic might fundamentally need to be paused — e.g. the returns workflow for an e-commerce order needs to wait for the item to be posted to the returns center, or execution of the next node needs input from a user and so must wait for a new HTTP request,
- the execution takes so long that the entire graph can't reliably be executed in a single continuous run — e.g. a deep research agent that might take hours to run,
- you want to run multiple graph nodes in parallel in different processes / hardware instances (note: parallel node execution is not yet supported in `pydantic-graph`, see [#704](https://github.com/pydantic/pydantic-ai/issues/704)).

Trying to make a conventional control flow (i.e., boolean logic and nested function calls) implementation compatible with these usage scenarios generally results in brittle and over-complicated spaghetti code, with the logic required to interrupt and resume execution dominating the implementation.

To allow graph runs to be interrupted and resumed, `pydantic-graph` provides state persistence — a system for snapshotting the state of a graph run before and after each node is run, allowing a graph run to be resumed from any point in the graph.

`pydantic-graph` includes three state persistence implementations:

- SimpleStatePersistence — Simple in-memory state persistence that just holds the latest snapshot. If no state persistence implementation is provided when running a graph, this is used by default.
- FullStatePersistence — In-memory state persistence that holds a list of snapshots.
- FileStatePersistence — File-based state persistence that saves snapshots to a JSON file.

In production applications, developers should implement their own state persistence by subclassing the BaseStatePersistence abstract base class, which might persist runs in a relational database like PostgreSQL.

At a high level, the role of `StatePersistence` implementations is to store and retrieve NodeSnapshot and EndSnapshot objects.

graph.iter_from_persistence() may be used to run the graph based on the state stored in persistence.

We can run the `count_down_graph` from [above](#iterating-over-a-graph), using graph.iter_from_persistence() and FileStatePersistence.
As you can see in this code, `run_node` requires no external application state (apart from state persistence) to be run, meaning graphs can easily be executed by distributed execution and queueing systems.

count_down_from_persistence.py

```python
from pathlib import Path

from pydantic_graph import End
from pydantic_graph.persistence.file import FileStatePersistence

from count_down import CountDown, CountDownState, count_down_graph


async def main():
    run_id = 'run_abc123'
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))  # (1)!
    state = CountDownState(counter=5)
    await count_down_graph.initialize(  # (2)!
        CountDown(), state=state, persistence=persistence
    )

    done = False
    while not done:
        done = await run_node(run_id)


async def run_node(run_id: str) -> bool:  # (3)!
    persistence = FileStatePersistence(Path(f'count_down_{run_id}.json'))
    async with count_down_graph.iter_from_persistence(persistence) as run:  # (4)!
        node_or_end = await run.next()  # (5)!

    print('Node:', node_or_end)
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: CountDown()
    #> Node: End(data=0)
    return isinstance(node_or_end, End)  # (6)!
```

1. Create a FileStatePersistence to use to start the graph.
1. Call graph.initialize() to set the initial graph state in the persistence object.
1. `run_node` is a pure function that doesn't need access to any other process state to run the next node of the graph, except the ID of the run.
1. Call graph.iter_from_persistence() to create a GraphRun object that will run the next node of the graph from the state stored in persistence.
1. run.next() will return either a node or an `End` object.
1. Check if the node is an `End` object; if it is, the graph run is complete.

*(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)*

### Example: Human in the loop

As noted above, state persistence allows graphs to be interrupted and resumed. One use case of this is to wait for user input before continuing a run.

In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong.

Instead of running the entire graph in a single process invocation, we run the graph by running the process repeatedly, optionally providing an answer to the question as a command line argument.
`ai_q_and_a_graph.py` — `question_graph` definition ai_q_and_a_graph.py ```python from __future__ import annotations as _annotations from dataclasses import dataclass, field from pydantic import BaseModel from pydantic_graph import ( BaseNode, End, Graph, GraphRunContext, ) from pydantic_ai import Agent, format_as_xml from pydantic_ai.messages import ModelMessage ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True) @dataclass class QuestionState: question: str | None = None ask_agent_messages: list[ModelMessage] = field(default_factory=list) evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) @dataclass class Ask(BaseNode[QuestionState]): async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, ) ctx.state.ask_agent_messages += result.new_messages() ctx.state.question = result.output return Answer(result.output) @dataclass class Answer(BaseNode[QuestionState]): question: str async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: answer = input(f'{self.question}: ') return Evaluate(answer) class EvaluationResult(BaseModel, use_attribute_docstrings=True): correct: bool """Whether the answer is correct.""" comment: str """Comment on the answer, reprimand the user if the answer is wrong.""" evaluate_agent = Agent( 'openai:gpt-4o', output_type=EvaluationResult, system_prompt='Given a question and answer, evaluate if the answer is correct.', ) @dataclass class Evaluate(BaseNode[QuestionState, None, str]): answer: str async def run( self, ctx: GraphRunContext[QuestionState], ) -> End[str] | Reprimand: assert ctx.state.question is not None result = await evaluate_agent.run( format_as_xml({'question': ctx.state.question, 'answer': self.answer}), message_history=ctx.state.evaluate_agent_messages, ) ctx.state.evaluate_agent_messages += result.new_messages() if result.output.correct: return End(result.output.comment) else: return Reprimand(result.output.comment) @dataclass class Reprimand(BaseNode[QuestionState]): comment: str async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') ctx.state.question = None return Ask() question_graph = Graph( nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState ) ``` *(This example is complete, it can be run "as is" with Python 3.10+)* ai_q_and_a_run.py ```python import sys from pathlib import Path from pydantic_graph import End from pydantic_graph.persistence.file import FileStatePersistence from pydantic_ai.messages import ModelMessage # noqa: F401 from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer async def main(): answer: str | None = sys.argv[1] if len(sys.argv) > 1 else None # (1)! persistence = FileStatePersistence(Path('question_graph.json')) # (2)! persistence.set_graph_types(question_graph) # (3)! if snapshot := await persistence.load_next(): # (4)! state = snapshot.state assert answer is not None node = Evaluate(answer) else: state = QuestionState() node = Ask() # (5)! async with question_graph.iter(node, state=state, persistence=persistence) as run: while True: node = await run.next() # (6)! if isinstance(node, End): # (7)! print('END:', node.data) history = await persistence.load_all() # (8)! print([e.node for e in history]) break elif isinstance(node, Answer): # (9)! print(node.question) #> What is the capital of France? break # otherwise just continue ``` 1. 
Get the user's answer from the command line, if provided. See [question graph example](../examples/question-graph/) for a complete example.
1. Create a state persistence instance; the `'question_graph.json'` file may or may not already exist.
1. Since we're using the persistence interface outside a graph run, we need to call set_graph_types to set the graph generic types `StateT` and `RunEndT` for the persistence instance. This is necessary so the persistence instance knows how to serialize and deserialize graph nodes.
1. If we've run the graph before, load_next will return a snapshot of the next node to run; here we use `state` from that snapshot and create a new `Evaluate` node with the answer provided on the command line.
1. If the graph hasn't been run before, we create a new `QuestionState` and start with the `Ask` node.
1. Call GraphRun.next() to run the node. This will return either a node or an `End` object.
1. If the node is an `End` object, the graph run is complete. The `data` field of the `End` object contains the comment returned by the `evaluate_agent` about the correct answer.
1. To demonstrate the state persistence, we call load_all to get all the snapshots from the persistence instance. This will return a list of Snapshot objects.
1. If the node is an `Answer` object, we print the question and break out of the loop to end the process and wait for user input.

*(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)*

For a complete example of this graph, see the [question graph example](../examples/question-graph/).

## Dependency Injection

As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on Graph and BaseNode, and the GraphRunContext.deps field.
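To show just the mechanics before the fuller example below, here's a minimal sketch (the `MyDeps` and `Multiply` names are hypothetical, not from the library or the docs): the deps type is declared as the second generic parameter of `BaseNode` and `GraphRunContext`, and the object passed as `deps=` to `Graph.run()` is available to every node as `ctx.deps`.

```python
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class MyDeps:  # hypothetical dependencies container for this sketch
    factor: int


@dataclass
class Multiply(BaseNode[None, MyDeps, int]):  # generic parameters: state, deps, run end type
    value: int

    async def run(self, ctx: GraphRunContext[None, MyDeps]) -> End[int]:
        # whatever was passed as `deps=` to `Graph.run()` is available here
        return End(self.value * ctx.deps.factor)


multiply_graph = Graph(nodes=[Multiply])


async def main():
    result = await multiply_graph.run(Multiply(3), deps=MyDeps(factor=2))
    print(result.output)
    #> 6
```

As with the other examples, you'd add `asyncio.run(main())` to actually execute `main`.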
As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a `ProcessPoolExecutor` to run the compute load in a separate process (this is a contrived example; `ProcessPoolExecutor` wouldn't actually improve performance here):

deps_example.py

```py
from __future__ import annotations

import asyncio
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext


@dataclass
class GraphDeps:
    executor: ProcessPoolExecutor


@dataclass
class DivisibleBy5(BaseNode[None, GraphDeps, int]):
    foo: int

    async def run(
        self,
        ctx: GraphRunContext[None, GraphDeps],
    ) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode[None, GraphDeps]):
    foo: int

    async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5:
        loop = asyncio.get_running_loop()
        compute_result = await loop.run_in_executor(
            ctx.deps.executor,
            self.compute,
        )
        return DivisibleBy5(compute_result)

    def compute(self) -> int:
        return self.foo + 1


fives_graph = Graph(nodes=[DivisibleBy5, Increment])


async def main():
    with ProcessPoolExecutor() as executor:
        deps = GraphDeps(executor)
        result = await fives_graph.run(DivisibleBy5(3), deps=deps, persistence=FullStatePersistence())
    print(result.output)
    #> 5
    # the full history is quite verbose (see below), so we'll just print the summary
    print([item.node for item in result.persistence.history])
    """
    [
        DivisibleBy5(foo=3),
        Increment(foo=3),
        DivisibleBy5(foo=4),
        Increment(foo=4),
        DivisibleBy5(foo=5),
        End(data=5),
    ]
    """
```

*(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)*

## Mermaid Diagrams

Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. These diagrams can be generated with:

- Graph.mermaid_code to generate the mermaid code for a graph
- Graph.mermaid_image to generate an image of the graph using [mermaid.ink](https://mermaid.ink/)
- Graph.mermaid_save to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file

Beyond the diagrams shown above, you can also customize mermaid diagrams with the following options:

- Edge allows you to apply a label to an edge
- BaseNode.docstring_notes and BaseNode.get_note allow you to add notes to nodes
- The highlighted_nodes parameter allows you to highlight specific node(s) in the diagram

Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#example-human-in-the-loop) example to:

- add labels to some edges
- add a note to the `Ask` node
- highlight the `Answer` node
- save the diagram as a `PNG` image to a file

ai_q_and_a_graph_extra.py

```python
...
from typing import Annotated

from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge

...

@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate question using GPT-4o."""

    docstring_notes = True

    async def run(
        self, ctx: GraphRunContext[QuestionState]
    ) -> Annotated[Answer, Edge(label='Ask the question')]:
        ...

...

@dataclass
class Evaluate(BaseNode[QuestionState]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> Annotated[End[str], Edge(label='success')] | Reprimand:
        ...

...
question_graph.mermaid_save('image.png', highlighted_nodes=[Answer]) ``` *(This example is not complete and cannot be run directly)* This would generate an image that looks like this: ``` --- title: question_graph --- stateDiagram-v2 Ask --> Answer: Ask the question note right of Ask Judge the answer. Decide on next step. end note Answer --> Evaluate Evaluate --> Reprimand Evaluate --> [*]: success Reprimand --> Ask classDef highlighted fill:#fdff32 class Answer highlighted ``` ### Setting Direction of the State Diagram You can specify the direction of the state diagram using one of the following values: - `'TB'`: Top to bottom, the diagram flows vertically from top to bottom. - `'LR'`: Left to right, the diagram flows horizontally from left to right. - `'RL'`: Right to left, the diagram flows horizontally from right to left. - `'BT'`: Bottom to top, the diagram flows vertically from bottom to top. Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB): vending_machine_diagram.py ```py from vending_machine import InsertCoin, vending_machine_graph vending_machine_graph.mermaid_code(start_node=InsertCoin, direction='LR') ``` ``` --- title: vending_machine_graph --- stateDiagram-v2 direction LR [*] --> InsertCoin InsertCoin --> CoinsInserted CoinsInserted --> SelectProduct CoinsInserted --> Purchase SelectProduct --> Purchase Purchase --> InsertCoin Purchase --> SelectProduct Purchase --> [*] ``` # API Reference # `pydantic_ai.agent` ### Agent Bases: `Generic[AgentDepsT, OutputDataT]` Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. Agents are generic in the dependency type they take AgentDepsT and the output type they return, OutputDataT. By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. Minimal usage example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python @final @dataclasses.dataclass(init=False) class Agent(Generic[AgentDepsT, OutputDataT]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. Minimal usage example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` """ model: models.Model | models.KnownModelName | str | None """The default model configured for this agent. We allow `str` here since the actual list of allowed models changes frequently. """ name: str | None """The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame when the agent is first run. """ end_strategy: EndStrategy """Strategy for handling tool calls when a final result is found.""" model_settings: ModelSettings | None """Optional model request settings to use for this agents's runs, by default. Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will be merged with this value, with the runtime argument taking priority. 
""" output_type: OutputSpec[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. """ instrument: InstrumentationSettings | bool | None """Options to automatically instrument with OpenTelemetry.""" _instrument_default: ClassVar[InstrumentationSettings | bool] = False _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) _instructions: str | None = dataclasses.field(repr=False) _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( repr=False ) _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False) _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) _default_retries: int = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) @overload def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, output_type: OutputSpec[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | None = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, ) -> None: ... @overload @deprecated( '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' 
) def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, result_type: type[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | None = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, result_tool_description: str | None = None, result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, ) -> None: ... def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads output_type: Any = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | None = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, **_deprecated_kwargs: Any, ): """Create an agent. Args: model: The default model to use for this agent, if not provide, you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently. output_type: The type of the output data, used to validate the data returned by the model, defaults to `str`. instructions: Instructions to use for this agent, you can also register instructions via a function with [`instructions`][pydantic_ai.Agent.instructions]. system_prompt: Static system prompts to use for this agent, you can also register system prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully parameterize the agent, and therefore get the best out of static type checking. If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright or add a type hint `: Agent[None, ]`. name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow before raising an error. output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. 
prepare_tools: custom method to prepare the tool definition of all tools for each step. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] for each server you want the agent to connect to. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` to defer the evaluation until the first run. Useful if you want to [override the model][pydantic_ai.Agent.override] for testing. end_strategy: Strategy for handling tool calls that are requested alongside a final result. See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. instrument: Set to True to automatically instrument with OpenTelemetry, which will use Logfire if it's configured. Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. If this isn't set, then the last value set by [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] will be used, which defaults to False. See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. """ if model is None or defer_model_check: self.model = model else: self.model = models.infer_model(model) self.end_strategy = end_strategy self.name = name self.model_settings = model_settings if 'result_type' in _deprecated_kwargs: if output_type is not str: # pragma: no cover raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') self.output_type = output_type self.instrument = instrument self._deps_type = deps_type self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) if self._deprecated_result_tool_name is not None: warnings.warn( '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', DeprecationWarning, stacklevel=2, ) self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) if self._deprecated_result_tool_description is not None: warnings.warn( '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', DeprecationWarning, stacklevel=2, ) result_retries = _deprecated_kwargs.pop('result_retries', None) if result_retries is not None: if output_retries is not None: # pragma: no cover raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') warnings.warn( '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 ) output_retries = result_retries default_output_mode = ( self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) _utils.validate_empty_kwargs(_deprecated_kwargs) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, default_mode=default_output_mode, 
name=self._deprecated_result_tool_name, description=self._deprecated_result_tool_description, ) self._output_validators = [] self._instructions = '' self._instructions_functions = [] if isinstance(instructions, (str, Callable)): instructions = [instructions] for instruction in instructions or []: if isinstance(instruction, str): self._instructions += instruction + '\n' else: self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) self._instructions = self._instructions.strip() or None self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) self._system_prompt_functions = [] self._system_prompt_dynamic_functions = {} self._function_tools = {} self._default_retries = retries self._max_result_retries = output_retries if output_retries is not None else retries self._mcp_servers = mcp_servers self._prepare_tools = prepare_tools self.history_processors = history_processors or [] for tool in tools: if isinstance(tool, Tool): self._register_tool(tool) else: self._register_tool(Tool(tool)) self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: """Set the instrumentation options for all agents where `instrument` is not set.""" Agent._instrument_default = instrument @overload async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[OutputDataT]: ... @overload async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[RunOutputDataT]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[RunOutputDataT]: ... 
async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. This method builds an internal agent graph (using system prompts, tools and result schemas) and then runs the graph to completion. The result of the run is returned. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): agent_run = await agent.run('What is the capital of France?') print(agent_run.output) #> Paris ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) async with self.iter( user_prompt=user_prompt, output_type=output_type, message_history=message_history, model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, ) as agent_run: async for _ in agent_run: pass assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result @overload def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... 
@overload def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, result_type: type[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the stream of events coming from the execution of tools. The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, and the final result of the run once it has completed. For more details, see the documentation of `AgentRun`. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): nodes = [] async with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: nodes.append(node) print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print(agent_run.result.output) #> Paris ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. 
message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) model_used = self._get_model(model) del model if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 output_schema = self._prepare_output_schema(output_type, model_used.profile) output_type_ = output_type or self.output_type # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) ) # Build the initial state usage = usage or _usage.Usage() state = _agent_graph.GraphAgentState( message_history=message_history[:] if message_history else [], usage=usage, retries=0, run_step=0, ) # We consider it a user error if a user tries to restrict the result type while having an output validator that # may change the result type from the restricted type to something else. Therefore, we consider the following # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() if isinstance(model_used, InstrumentedModel): instrumentation_settings = model_used.settings tracer = model_used.settings.tracer else: instrumentation_settings = None tracer = NoOpTracer() agent_name = self.name or 'agent' run_span = tracer.start_span( 'agent run', attributes={ 'model_name': model_used.model_name if model_used else 'no-model', 'agent_name': agent_name, 'logfire.msg': f'{agent_name} run', }, ) async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: parts = [ self._instructions, *[await func.run(run_context) for func in self._instructions_functions], ] model_profile = model_used.profile if isinstance(output_schema, _output.PromptedOutputSchema): instructions = output_schema.instructions(model_profile.prompted_output_template) parts.append(instructions) parts = [p for p in parts if p] if not parts: return None return '\n\n'.join(parts).strip() # Copy the function tools so that retry state is agent-run-specific # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. 
run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()} graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, prompt=user_prompt, new_message_index=new_message_index, model=model_used, model_settings=model_settings, usage_limits=usage_limits, max_result_retries=self._max_result_retries, end_strategy=self.end_strategy, output_schema=output_schema, output_validators=output_validators, history_processors=self.history_processors, function_tools=run_function_tools, mcp_servers=self._mcp_servers, default_retries=self._default_retries, tracer=tracer, prepare_tools=self._prepare_tools, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) start_node = _agent_graph.UserPromptNode[AgentDepsT]( user_prompt=user_prompt, instructions=self._instructions, instructions_functions=self._instructions_functions, system_prompts=self._system_prompts, system_prompt_functions=self._system_prompt_functions, system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) try: async with graph.iter( start_node, state=state, deps=graph_deps, span=use_span(run_span) if run_span.is_recording() else None, infer_name=False, ) as graph_run: agent_run = AgentRun(graph_run) yield agent_run if (final_result := agent_run.result) is not None and run_span.is_recording(): run_span.set_attribute( 'final_result', ( final_result.output if isinstance(final_result.output, str) else json.dumps(InstrumentedModel.serialize_any(final_result.output)) ), ) finally: try: if instrumentation_settings and run_span.is_recording(): run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) finally: run_span.end() def _run_span_end_attributes( self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings ): return { **usage.opentelemetry_attributes(), 'all_messages_events': json.dumps( [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] ), 'logfire.json_schema': json.dumps( { 'type': 'object', 'properties': { 'all_messages_events': {'type': 'array'}, 'final_result': {'type': 'object'}, }, } ), } @overload def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[OutputDataT]: ... @overload def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[RunOutputDataT]: ... 
@overload @deprecated('`result_type` is deprecated, use `output_type` instead.') def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result_sync = agent.run_sync('What is the capital of Italy?') print(result_sync.output) #> Rome ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) return get_event_loop().run_until_complete( self.run( user_prompt, output_type=output_type, message_history=message_history, model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, infer_name=False, ) ) @overload def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... 
@overload def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): async with agent.run_stream('What is the capital of the UK?') as response: print(await response.get_output()) #> London ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ # TODO: We need to deprecate this now that we have the `iter` method. # Before that, though, we should add an event for when we reach the final result of the stream. 
if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) yielded = False async with self.iter( user_prompt, output_type=output_type, message_history=message_history, model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, infer_name=False, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node node = first_node while True: if self.is_model_request_node(node): graph_ctx = agent_run.ctx async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage] async def stream_to_final( s: models.StreamedResponse, ) -> FinalResult[models.StreamedResponse] | None: output_schema = graph_ctx.deps.output_schema async for maybe_part_event in streamed_response: if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart) and isinstance( output_schema, _output.TextOutputSchema ): return FinalResult(s, None, None) elif isinstance(new_part, _messages.ToolCallPart) and isinstance( output_schema, _output.ToolOutputSchema ): # pragma: no branch for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) return None final_result_details = await stream_to_final(streamed_response) if final_result_details is not None: if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True messages = graph_ctx.state.message_history.copy() async def on_complete() -> None: """Called when the stream has completed. The model response will have been added to messages by now by `StreamedRunResult._marked_completed`. """ last_message = messages[-1] assert isinstance(last_message, _messages.ModelResponse) tool_calls = [ part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) ] parts: list[_messages.ModelRequestPart] = [] async for _event in _agent_graph.process_function_tools( tool_calls, final_result_details.tool_name, final_result_details.tool_call_id, graph_ctx, parts, ): pass # TODO: Should we do something here related to the retry count? # Maybe we should move the incrementing of the retry count to where we actually make a request? 
# if any(isinstance(part, _messages.RetryPromptPart) for part in parts): # ctx.state.increment_retries(ctx.deps.max_result_retries) if parts: messages.append(_messages.ModelRequest(parts)) yield StreamedRunResult( messages, graph_ctx.deps.new_message_index, graph_ctx.deps.usage_limits, streamed_response, graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), graph_ctx.deps.output_validators, final_result_details.tool_name, on_complete, ) break next_node = await agent_run.next(node) if not isinstance(next_node, _agent_graph.AgentNode): raise exceptions.AgentRunError( # pragma: no cover 'Should have produced a StreamedRunResult before getting here' ) node = cast(_agent_graph.AgentNode[Any, Any], next_node) if not yielded: raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover @contextmanager def override( self, *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies and model. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) else: deps_token = None if _utils.is_set(model): model_token = self._override_model.set(_utils.Some(models.infer_model(model))) else: model_token = None try: yield finally: if deps_token is not None: self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) @overload def instructions( self, func: Callable[[RunContext[AgentDepsT]], str], / ) -> Callable[[RunContext[AgentDepsT]], str]: ... @overload def instructions( self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], / ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ... @overload def instructions(self, func: Callable[[], str], /) -> Callable[[], str]: ... @overload def instructions(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ... @overload def instructions( self, / ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ... def instructions( self, func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None, /, ) -> ( Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]] | _system_prompt.SystemPromptFunc[AgentDepsT] ): """Decorator to register an instructions function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. Can decorate a sync or async functions. The decorator can be used bare (`agent.instructions`). Overloads for every possible signature of `instructions` are included so the decorator doesn't obscure the type of the function. 
Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=str) @agent.instructions def simple_instructions() -> str: return 'foobar' @agent.instructions async def async_instructions(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' ``` """ if func is None: def decorator( func_: _system_prompt.SystemPromptFunc[AgentDepsT], ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_)) return func_ return decorator else: self._instructions_functions.append(_system_prompt.SystemPromptRunner(func)) return func @overload def system_prompt( self, func: Callable[[RunContext[AgentDepsT]], str], / ) -> Callable[[RunContext[AgentDepsT]], str]: ... @overload def system_prompt( self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], / ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ... @overload def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ... @overload def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ... @overload def system_prompt( self, /, *, dynamic: bool = False ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ... def system_prompt( self, func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None, /, *, dynamic: bool = False, ) -> ( Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]] | _system_prompt.SystemPromptFunc[AgentDepsT] ): """Decorator to register a system prompt function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. Can decorate a sync or async functions. The decorator can be used either bare (`agent.system_prompt`) or as a function call (`agent.system_prompt(...)`), see the examples below. Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. Args: func: The function to decorate dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided, see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref] Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=str) @agent.system_prompt def simple_system_prompt() -> str: return 'foobar' @agent.system_prompt(dynamic=True) async def async_system_prompt(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' ``` """ if func is None: def decorator( func_: _system_prompt.SystemPromptFunc[AgentDepsT], ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic) self._system_prompt_functions.append(runner) if dynamic: # pragma: lax no cover self._system_prompt_dynamic_functions[func_.__qualname__] = runner return func_ return decorator else: assert not dynamic, "dynamic can't be True in this case" self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic)) return func @overload def output_validator( self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], / ) -> Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT]: ... @overload def output_validator( self, func: Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]], / ) -> Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]]: ... 
@overload def output_validator( self, func: Callable[[OutputDataT], OutputDataT], / ) -> Callable[[OutputDataT], OutputDataT]: ... @overload def output_validator( self, func: Callable[[OutputDataT], Awaitable[OutputDataT]], / ) -> Callable[[OutputDataT], Awaitable[OutputDataT]]: ... def output_validator( self, func: _output.OutputValidatorFunc[AgentDepsT, OutputDataT], / ) -> _output.OutputValidatorFunc[AgentDepsT, OutputDataT]: """Decorator to register an output validator function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. Can decorate a sync or async functions. Overloads for every possible signature of `output_validator` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. Example: ```python from pydantic_ai import Agent, ModelRetry, RunContext agent = Agent('test', deps_type=str) @agent.output_validator def output_validator_simple(data: str) -> str: if 'wrong' in data: raise ModelRetry('wrong response') return data @agent.output_validator async def output_validator_deps(ctx: RunContext[str], data: str) -> str: if ctx.deps in data: raise ModelRetry('wrong response') return data result = agent.run_sync('foobar', deps='spam') print(result.output) #> success (no tool calls) ``` """ self._output_validators.append(_output.OutputValidator[AgentDepsT, Any](func)) return func @deprecated('`result_validator` is deprecated, use `output_validator` instead.') def result_validator(self, func: Any, /) -> Any: warnings.warn( '`result_validator` is deprecated, use `output_validator` instead.', DeprecationWarning, stacklevel=2 ) return self.output_validator(func) # type: ignore @overload def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ... @overload def tool( self, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ... def tool( self, func: ToolFuncContext[AgentDepsT, ToolParams] | None = None, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Any: """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. Can decorate a sync or async functions. The docstring is inspected to extract both the tool description and description of each parameter, [learn more](../tools.md#function-tools-and-schema). We can't add overloads for every possible signature of tool, since the return type is a recursive union so the signature of functions decorated with `@agent.tool` is obscured. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=int) @agent.tool def foobar(ctx: RunContext[int], x: int) -> int: return ctx.deps + x @agent.tool(retries=2) async def spam(ctx: RunContext[str], y: float) -> float: return ctx.deps + y result = agent.run_sync('foobar', deps=1) print(result.output) #> {"foobar":1,"spam":1.0} ``` Args: func: The tool function to register. 
name: The name of the tool, defaults to the function name. retries: The number of retries to allow for this tool, defaults to the agent's default retries, which defaults to 1. prepare: custom method to prepare the tool definition for each step, return `None` to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ if func is None: def tool_decorator( func_: ToolFuncContext[AgentDepsT, ToolParams], ) -> ToolFuncContext[AgentDepsT, ToolParams]: # noinspection PyTypeChecker self._register_function( func_, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func_ return tool_decorator else: # noinspection PyTypeChecker self._register_function( func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func @overload def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... @overload def tool_plain( self, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ... def tool_plain( self, func: ToolFuncPlain[ToolParams] | None = None, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Any: """Decorator to register a tool function which DOES NOT take `RunContext` as an argument. Can decorate a sync or async functions. The docstring is inspected to extract both the tool description and description of each parameter, [learn more](../tools.md#function-tools-and-schema). We can't add overloads for every possible signature of tool, since the return type is a recursive union so the signature of functions decorated with `@agent.tool` is obscured. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test') @agent.tool def foobar(ctx: RunContext[int]) -> int: return 123 @agent.tool(retries=2) async def spam(ctx: RunContext[str]) -> float: return 3.14 result = agent.run_sync('foobar', deps=1) print(result.output) #> {"foobar":123,"spam":3.14} ``` Args: func: The tool function to register. name: The name of the tool, defaults to the function name. retries: The number of retries to allow for this tool, defaults to the agent's default retries, which defaults to 1. prepare: custom method to prepare the tool definition for each step, return `None` to omit this tool from a given step. 
This is useful if you want to customise a tool at call time, or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ if func is None: def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: # noinspection PyTypeChecker self._register_function( func_, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func_ return tool_decorator else: self._register_function( func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func def _register_function( self, func: ToolFuncEither[AgentDepsT, ToolParams], takes_ctx: bool, name: str | None, retries: int | None, prepare: ToolPrepareFunc[AgentDepsT] | None, docstring_format: DocstringFormat, require_parameter_descriptions: bool, schema_generator: type[GenerateJsonSchema], strict: bool | None, ) -> None: """Private utility to register a function as a tool.""" retries_ = retries if retries is not None else self._default_retries tool = Tool[AgentDepsT]( func, takes_ctx=takes_ctx, name=name, max_retries=retries_, prepare=prepare, docstring_format=docstring_format, require_parameter_descriptions=require_parameter_descriptions, schema_generator=schema_generator, strict=strict, ) self._register_tool(tool) def _register_tool(self, tool: Tool[AgentDepsT]) -> None: """Private utility to register a tool instance.""" if tool.max_retries is None: # noinspection PyTypeChecker tool = dataclasses.replace(tool, max_retries=self._default_retries) if tool.name in self._function_tools: raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') if tool.name in self._output_schema.tools: raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') self._function_tools[tool.name] = tool def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model: """Create a model configured for this agent. Args: model: model to use for this run, required if `model` was not set when creating the agent. Returns: The model used """ model_: models.Model if some_model := self._override_model.get(): # we don't want `override()` to cover up errors from the model not being defined, hence this check if model is None and self.model is None: raise exceptions.UserError( '`model` must either be set on the agent or included when calling it. 
' '(Even when `override(model=...)` is customizing the model that will actually be called)' ) model_ = some_model.value elif model is not None: model_ = models.infer_model(model) elif self.model is not None: # noinspection PyTypeChecker model_ = self.model = models.infer_model(self.model) else: raise exceptions.UserError('`model` must either be set on the agent or included when calling it.') instrument = self.instrument if instrument is None: instrument = self._instrument_default return instrument_model(model_, instrument) def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: """Get deps for a run. If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call. We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. """ if some_deps := self._override_deps.get(): return some_deps.value else: return deps def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. Usage should be `self._infer_name(inspect.currentframe())`. """ assert self.name is None, 'Name already set' if function_frame is not None: # pragma: no branch if parent_frame := function_frame.f_back: # pragma: no branch for name, item in parent_frame.f_locals.items(): if item is self: self.name = name return if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch # if we couldn't find the agent in locals and globals are a different dict, try globals for name, item in parent_frame.f_globals.items(): if item is self: self.name = name return @property @deprecated( 'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None ) def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile ) -> _output.OutputSchema[RunOutputDataT]: if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') schema = _output.OutputSchema[RunOutputDataT].build( output_type, name=self._deprecated_result_tool_name, description=self._deprecated_result_tool_description, default_mode=model_profile.default_structured_output_mode, ) else: schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode) schema.raise_if_unsupported(model_profile) return schema # pyright: ignore[reportReturnType] @staticmethod def is_model_request_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: """Check if the node is a `ModelRequestNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, _agent_graph.ModelRequestNode) @staticmethod def is_call_tools_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: """Check if the node is a `CallToolsNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. 
""" return isinstance(node, _agent_graph.CallToolsNode) @staticmethod def is_user_prompt_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: """Check if the node is a `UserPromptNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, _agent_graph.UserPromptNode) @staticmethod def is_end_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[End[result.FinalResult[S]]]: """Check if the node is a `End`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, End) @asynccontextmanager async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. Returns: a context manager to start and shutdown the servers. """ try: sampling_model: models.Model | None = self._get_model(model) except exceptions.UserError: # pragma: no cover sampling_model = None exit_stack = AsyncExitStack() try: for mcp_server in self._mcp_servers: if sampling_model is not None: # pragma: no branch mcp_server.sampling_model = sampling_model await exit_stack.enter_async_context(mcp_server) yield finally: await exit_stack.aclose() def to_a2a( self, *, storage: Storage | None = None, broker: Broker | None = None, # Agent card name: str | None = None, url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, provider: Provider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, routes: Sequence[Route] | None = None, middleware: Sequence[Middleware] | None = None, exception_handlers: dict[Any, ExceptionHandler] | None = None, lifespan: Lifespan[FastA2A] | None = None, ) -> FastA2A: """Convert the agent to a FastA2A application. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') app = agent.to_a2a() ``` The `app` is an ASGI application that can be used with any ASGI server. To run the application, you can use the following command: ```bash uvicorn app:app --host 0.0.0.0 --port 8000 ``` """ from ._a2a import agent_to_a2a return agent_to_a2a( self, storage=storage, broker=broker, name=name, url=url, version=version, description=description, provider=provider, skills=skills, debug=debug, routes=routes, middleware=middleware, exception_handlers=exception_handlers, lifespan=lifespan, ) async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: """Run the agent in a CLI chat interface. Args: deps: The dependencies to pass to the agent. prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. Example: ```python {title="agent_to_cli.py" test="skip"} from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') async def main(): await agent.to_cli() ``` """ from rich.console import Console from pydantic_ai._cli import run_chat await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: """Run the agent in a CLI chat interface with the non-async interface. Args: deps: The dependencies to pass to the agent. 
prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. ```python {title="agent_to_cli_sync.py" test="skip"} from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') agent.to_cli_sync() agent.to_cli_sync(prog_name='assistant') ``` """ return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) ```` #### model ```python model: Model | KnownModelName | str | None ``` The default model configured for this agent. We allow `str` here since the actual list of allowed models changes frequently. #### __init__ ```python __init__( model: Model | KnownModelName | str | None = None, *, output_type: OutputSpec[OutputDataT] = str, instructions: ( str | SystemPromptFunc[AgentDepsT] | Sequence[str | SystemPromptFunc[AgentDepsT]] | None ) = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, output_retries: int | None = None, tools: Sequence[ Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...] ] = (), prepare_tools: ( ToolsPrepareFunc[AgentDepsT] | None ) = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = "early", instrument: ( InstrumentationSettings | bool | None ) = None, history_processors: ( Sequence[HistoryProcessor[AgentDepsT]] | None ) = None ) -> None ``` ```python __init__( model: Model | KnownModelName | str | None = None, *, result_type: type[OutputDataT] = str, instructions: ( str | SystemPromptFunc[AgentDepsT] | Sequence[str | SystemPromptFunc[AgentDepsT]] | None ) = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, result_tool_name: str = DEFAULT_OUTPUT_TOOL_NAME, result_tool_description: str | None = None, result_retries: int | None = None, tools: Sequence[ Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...] ] = (), prepare_tools: ( ToolsPrepareFunc[AgentDepsT] | None ) = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = "early", instrument: ( InstrumentationSettings | bool | None ) = None, history_processors: ( Sequence[HistoryProcessor[AgentDepsT]] | None ) = None ) -> None ``` ```python __init__( model: Model | KnownModelName | str | None = None, *, output_type: Any = str, instructions: ( str | SystemPromptFunc[AgentDepsT] | Sequence[str | SystemPromptFunc[AgentDepsT]] | None ) = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, output_retries: int | None = None, tools: Sequence[ Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...] ] = (), prepare_tools: ( ToolsPrepareFunc[AgentDepsT] | None ) = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = "early", instrument: ( InstrumentationSettings | bool | None ) = None, history_processors: ( Sequence[HistoryProcessor[AgentDepsT]] | None ) = None, **_deprecated_kwargs: Any ) ``` Create an agent. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model` | `Model | KnownModelName | str | None` | The default model to use for this agent, if not provide, you must provide the model when calling it. We allow str here since the actual list of allowed models changes frequently. 
| `None` | | `output_type` | `Any` | The type of the output data, used to validate the data returned by the model, defaults to `str`. | `str` | | `instructions` | `str | SystemPromptFunc[AgentDepsT] | Sequence[str | SystemPromptFunc[AgentDepsT]] | None` | Instructions to use for this agent; you can also register instructions via a function with instructions. | `None` | | `system_prompt` | `str | Sequence[str]` | Static system prompts to use for this agent; you can also register system prompts via a function with system_prompt. | `()` | | `deps_type` | `type[AgentDepsT]` | The type used for dependency injection; this parameter exists solely to allow you to fully parameterize the agent, and therefore get the best out of static type checking. If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright or add a type hint `: Agent[None, <return type>]`. | `NoneType` | | `name` | `str | None` | The name of the agent, used for logging. If None, we try to infer the agent name from the call frame when the agent is first run. | `None` | | `model_settings` | `ModelSettings | None` | Optional model request settings to use for this agent's runs, by default. | `None` | | `retries` | `int` | The default number of retries to allow before raising an error. | `1` | | `output_retries` | `int | None` | The maximum number of retries to allow for result validation, defaults to retries. | `None` | | `tools` | `Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]` | Tools to register with the agent; you can also register tools via the decorators @agent.tool and @agent.tool_plain. | `()` | | `prepare_tools` | `ToolsPrepareFunc[AgentDepsT] | None` | Custom method to prepare the tool definition of all tools for each step. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See ToolsPrepareFunc. | `None` | | `mcp_servers` | `Sequence[MCPServer]` | MCP servers to register with the agent. You should register an MCPServer for each server you want the agent to connect to. | `()` | | `defer_model_check` | `bool` | By default, if you provide a named model, it's evaluated to create a Model instance immediately, which checks for the necessary environment variables. Set this to `True` to defer the evaluation until the first run. Useful if you want to override the model for testing. | `False` | | `end_strategy` | `EndStrategy` | Strategy for handling tool calls that are requested alongside a final result. See EndStrategy for more information. | `'early'` | | `instrument` | `InstrumentationSettings | bool | None` | Set to True to automatically instrument with OpenTelemetry, which will use Logfire if it's configured. Set to an instance of InstrumentationSettings to customize. If this isn't set, then the last value set by Agent.instrument_all() will be used, which defaults to False. See the Debugging and Monitoring guide for more info. | `None` | | `history_processors` | `Sequence[HistoryProcessor[AgentDepsT]] | None` | Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence.
| `None` | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads output_type: Any = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | None = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, **_deprecated_kwargs: Any, ): """Create an agent. Args: model: The default model to use for this agent, if not provide, you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently. output_type: The type of the output data, used to validate the data returned by the model, defaults to `str`. instructions: Instructions to use for this agent, you can also register instructions via a function with [`instructions`][pydantic_ai.Agent.instructions]. system_prompt: Static system prompts to use for this agent, you can also register system prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully parameterize the agent, and therefore get the best out of static type checking. If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright or add a type hint `: Agent[None, ]`. name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow before raising an error. output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. prepare_tools: custom method to prepare the tool definition of all tools for each step. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] for each server you want the agent to connect to. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` to defer the evaluation until the first run. Useful if you want to [override the model][pydantic_ai.Agent.override] for testing. end_strategy: Strategy for handling tool calls that are requested alongside a final result. 
See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. instrument: Set to True to automatically instrument with OpenTelemetry, which will use Logfire if it's configured. Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. If this isn't set, then the last value set by [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] will be used, which defaults to False. See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. """ if model is None or defer_model_check: self.model = model else: self.model = models.infer_model(model) self.end_strategy = end_strategy self.name = name self.model_settings = model_settings if 'result_type' in _deprecated_kwargs: if output_type is not str: # pragma: no cover raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') self.output_type = output_type self.instrument = instrument self._deps_type = deps_type self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) if self._deprecated_result_tool_name is not None: warnings.warn( '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', DeprecationWarning, stacklevel=2, ) self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) if self._deprecated_result_tool_description is not None: warnings.warn( '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', DeprecationWarning, stacklevel=2, ) result_retries = _deprecated_kwargs.pop('result_retries', None) if result_retries is not None: if output_retries is not None: # pragma: no cover raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') warnings.warn( '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 ) output_retries = result_retries default_output_mode = ( self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) _utils.validate_empty_kwargs(_deprecated_kwargs) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, default_mode=default_output_mode, name=self._deprecated_result_tool_name, description=self._deprecated_result_tool_description, ) self._output_validators = [] self._instructions = '' self._instructions_functions = [] if isinstance(instructions, (str, Callable)): instructions = [instructions] for instruction in instructions or []: if isinstance(instruction, str): self._instructions += instruction + '\n' else: self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) self._instructions = self._instructions.strip() or None self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) self._system_prompt_functions = [] self._system_prompt_dynamic_functions = {} self._function_tools = {} self._default_retries = retries self._max_result_retries = output_retries if output_retries is not None else retries self._mcp_servers = mcp_servers self._prepare_tools = prepare_tools self.history_processors 
= history_processors or [] for tool in tools: if isinstance(tool, Tool): self._register_tool(tool) else: self._register_tool(Tool(tool)) self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) ``` #### end_strategy ```python end_strategy: EndStrategy = end_strategy ``` Strategy for handling tool calls when a final result is found. #### name ```python name: str | None = name ``` The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame when the agent is first run. #### model_settings ```python model_settings: ModelSettings | None = model_settings ``` Optional model request settings to use for this agents's runs, by default. Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will be merged with this value, with the runtime argument taking priority. #### output_type ```python output_type: OutputSpec[OutputDataT] = output_type ``` The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. #### instrument ```python instrument: InstrumentationSettings | bool | None = ( instrument ) ``` Options to automatically instrument with OpenTelemetry. #### instrument_all ```python instrument_all( instrument: InstrumentationSettings | bool = True, ) -> None ``` Set the instrumentation options for all agents where `instrument` is not set. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: """Set the instrumentation options for all agents where `instrument` is not set.""" Agent._instrument_default = instrument ``` #### run ```python run( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: None = None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AgentRunResult[OutputDataT] ``` ```python run( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AgentRunResult[RunOutputDataT] ``` ```python run( user_prompt: str | Sequence[UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AgentRunResult[RunOutputDataT] ``` ```python run( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never ) -> AgentRunResult[Any] ``` Run the agent with a user prompt in async 
mode. This method builds an internal agent graph (using system prompts, tools and result schemas) and then runs the graph to completion. The result of the run is returned. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): agent_run = await agent.run('What is the capital of France?') print(agent_run.output) #> Paris ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `user_prompt` | `str | Sequence[UserContent] | None` | User input to start/continue the conversation. | `None` | | `output_type` | `OutputSpec[RunOutputDataT] | None` | Custom output type to use for this run, output_type may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. | `None` | | `message_history` | `list[ModelMessage] | None` | History of the conversation so far. | `None` | | `model` | `Model | KnownModelName | str | None` | Optional model to use for this run, required if model was not set when creating the agent. | `None` | | `deps` | `AgentDepsT` | Optional dependencies to use for this run. | `None` | | `model_settings` | `ModelSettings | None` | Optional settings to use for this model's request. | `None` | | `usage_limits` | `UsageLimits | None` | Optional limits on model request count or token usage. | `None` | | `usage` | `Usage | None` | Optional usage to start with, useful for resuming a conversation or agents used in tools. | `None` | | `infer_name` | `bool` | Whether to try to infer the agent name from the call frame if it's not set. | `True` | Returns: | Type | Description | | --- | --- | | `AgentRunResult[Any]` | The result of the run. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. This method builds an internal agent graph (using system prompts, tools and result schemas) and then runs the graph to completion. The result of the run is returned. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): agent_run = await agent.run('What is the capital of France?') print(agent_run.output) #> Paris ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. 
""" if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) async with self.iter( user_prompt=user_prompt, output_type=output_type, message_history=message_history, model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, ) as agent_run: async for _ in agent_run: pass assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result ```` #### iter ```python iter( user_prompt: str | Sequence[UserContent] | None, *, output_type: None = None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never ) -> AbstractAsyncContextManager[ AgentRun[AgentDepsT, OutputDataT] ] ``` ```python iter( user_prompt: str | Sequence[UserContent] | None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never ) -> AbstractAsyncContextManager[ AgentRun[AgentDepsT, RunOutputDataT] ] ``` ```python iter( user_prompt: str | Sequence[UserContent] | None, *, result_type: type[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]] ``` ```python iter( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never ) -> AsyncIterator[AgentRun[AgentDepsT, Any]] ``` A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the stream of events coming from the execution of tools. The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, and the final result of the run once it has completed. For more details, see the documentation of `AgentRun`. 
Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): nodes = [] async with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: nodes.append(node) print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print(agent_run.result.output) #> Paris ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `user_prompt` | `str | Sequence[UserContent] | None` | User input to start/continue the conversation. | `None` | | `output_type` | `OutputSpec[RunOutputDataT] | None` | Custom output type to use for this run, output_type may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. | `None` | | `message_history` | `list[ModelMessage] | None` | History of the conversation so far. | `None` | | `model` | `Model | KnownModelName | str | None` | Optional model to use for this run, required if model was not set when creating the agent. | `None` | | `deps` | `AgentDepsT` | Optional dependencies to use for this run. | `None` | | `model_settings` | `ModelSettings | None` | Optional settings to use for this model's request. | `None` | | `usage_limits` | `UsageLimits | None` | Optional limits on model request count or token usage. | `None` | | `usage` | `Usage | None` | Optional usage to start with, useful for resuming a conversation or agents used in tools. | `None` | | `infer_name` | `bool` | Whether to try to infer the agent name from the call frame if it's not set. | `True` | Returns: | Type | Description | | --- | --- | | `AsyncIterator[AgentRun[AgentDepsT, Any]]` | The result of the run. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python @asynccontextmanager async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the stream of events coming from the execution of tools. The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, and the final result of the run once it has completed. 
For more details, see the documentation of `AgentRun`. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): nodes = [] async with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: nodes.append(node) print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print(agent_run.result.output) #> Paris ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) model_used = self._get_model(model) del model if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 output_schema = self._prepare_output_schema(output_type, model_used.profile) output_type_ = output_type or self.output_type # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) ) # Build the initial state usage = usage or _usage.Usage() state = _agent_graph.GraphAgentState( message_history=message_history[:] if message_history else [], usage=usage, retries=0, run_step=0, ) # We consider it a user error if a user tries to restrict the result type while having an output validator that # may change the result type from the restricted type to something else. Therefore, we consider the following # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. 
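# Note: the run-level model_settings passed to this method are merged over the agent's defaults below (runtime values take priority), and tracing falls back to a NoOpTracer unless the model is instrumented.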
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() if isinstance(model_used, InstrumentedModel): instrumentation_settings = model_used.settings tracer = model_used.settings.tracer else: instrumentation_settings = None tracer = NoOpTracer() agent_name = self.name or 'agent' run_span = tracer.start_span( 'agent run', attributes={ 'model_name': model_used.model_name if model_used else 'no-model', 'agent_name': agent_name, 'logfire.msg': f'{agent_name} run', }, ) async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: parts = [ self._instructions, *[await func.run(run_context) for func in self._instructions_functions], ] model_profile = model_used.profile if isinstance(output_schema, _output.PromptedOutputSchema): instructions = output_schema.instructions(model_profile.prompted_output_template) parts.append(instructions) parts = [p for p in parts if p] if not parts: return None return '\n\n'.join(parts).strip() # Copy the function tools so that retry state is agent-run-specific # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()} graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, prompt=user_prompt, new_message_index=new_message_index, model=model_used, model_settings=model_settings, usage_limits=usage_limits, max_result_retries=self._max_result_retries, end_strategy=self.end_strategy, output_schema=output_schema, output_validators=output_validators, history_processors=self.history_processors, function_tools=run_function_tools, mcp_servers=self._mcp_servers, default_retries=self._default_retries, tracer=tracer, prepare_tools=self._prepare_tools, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) start_node = _agent_graph.UserPromptNode[AgentDepsT]( user_prompt=user_prompt, instructions=self._instructions, instructions_functions=self._instructions_functions, system_prompts=self._system_prompts, system_prompt_functions=self._system_prompt_functions, system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) try: async with graph.iter( start_node, state=state, deps=graph_deps, span=use_span(run_span) if run_span.is_recording() else None, infer_name=False, ) as graph_run: agent_run = AgentRun(graph_run) yield agent_run if (final_result := agent_run.result) is not None and run_span.is_recording(): run_span.set_attribute( 'final_result', ( final_result.output if isinstance(final_result.output, str) else json.dumps(InstrumentedModel.serialize_any(final_result.output)) ), ) finally: try: if instrumentation_settings and run_span.is_recording(): run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) finally: run_span.end() ```` #### run_sync ```python run_sync( user_prompt: str | Sequence[UserContent] | None = None, *, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AgentRunResult[OutputDataT] ``` ```python run_sync( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = 
None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AgentRunResult[RunOutputDataT] ``` ```python run_sync( user_prompt: str | Sequence[UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AgentRunResult[RunOutputDataT] ``` ```python run_sync( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never ) -> AgentRunResult[Any] ``` Synchronously run the agent with a user prompt. This is a convenience method that wraps self.run with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result_sync = agent.run_sync('What is the capital of Italy?') print(result_sync.output) #> Rome ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `user_prompt` | `str | Sequence[UserContent] | None` | User input to start/continue the conversation. | `None` | | `output_type` | `OutputSpec[RunOutputDataT] | None` | Custom output type to use for this run, output_type may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. | `None` | | `message_history` | `list[ModelMessage] | None` | History of the conversation so far. | `None` | | `model` | `Model | KnownModelName | str | None` | Optional model to use for this run, required if model was not set when creating the agent. | `None` | | `deps` | `AgentDepsT` | Optional dependencies to use for this run. | `None` | | `model_settings` | `ModelSettings | None` | Optional settings to use for this model's request. | `None` | | `usage_limits` | `UsageLimits | None` | Optional limits on model request count or token usage. | `None` | | `usage` | `Usage | None` | Optional usage to start with, useful for resuming a conversation or agents used in tools. | `None` | | `infer_name` | `bool` | Whether to try to infer the agent name from the call frame if it's not set. | `True` | Returns: | Type | Description | | --- | --- | | `AgentRunResult[Any]` | The result of the run. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. 
This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') result_sync = agent.run_sync('What is the capital of Italy?') print(result_sync.output) #> Rome ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) return get_event_loop().run_until_complete( self.run( user_prompt, output_type=output_type, message_history=message_history, model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, infer_name=False, ) ) ```` #### run_stream ```python run_stream( user_prompt: str | Sequence[UserContent] | None = None, *, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AbstractAsyncContextManager[ StreamedRunResult[AgentDepsT, OutputDataT] ] ``` ```python run_stream( user_prompt: str | Sequence[UserContent], *, output_type: OutputSpec[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AbstractAsyncContextManager[ StreamedRunResult[AgentDepsT, RunOutputDataT] ] ``` ```python run_stream( user_prompt: str | Sequence[UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True ) -> AbstractAsyncContextManager[ StreamedRunResult[AgentDepsT, RunOutputDataT] ] ``` ```python run_stream( user_prompt: str | Sequence[UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[ModelMessage] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, 
model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]] ``` Run the agent with a user prompt in async mode, returning a streamed response. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): async with agent.run_stream('What is the capital of the UK?') as response: print(await response.get_output()) #> London ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `user_prompt` | `str | Sequence[UserContent] | None` | User input to start/continue the conversation. | `None` | | `output_type` | `OutputSpec[RunOutputDataT] | None` | Custom output type to use for this run, output_type may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. | `None` | | `message_history` | `list[ModelMessage] | None` | History of the conversation so far. | `None` | | `model` | `Model | KnownModelName | str | None` | Optional model to use for this run, required if model was not set when creating the agent. | `None` | | `deps` | `AgentDepsT` | Optional dependencies to use for this run. | `None` | | `model_settings` | `ModelSettings | None` | Optional settings to use for this model's request. | `None` | | `usage_limits` | `UsageLimits | None` | Optional limits on model request count or token usage. | `None` | | `usage` | `Usage | None` | Optional usage to start with, useful for resuming a conversation or agents used in tools. | `None` | | `infer_name` | `bool` | Whether to try to infer the agent name from the call frame if it's not set. | `True` | Returns: | Type | Description | | --- | --- | | `AsyncIterator[StreamedRunResult[AgentDepsT, Any]]` | The result of the run. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python @asynccontextmanager async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): async with agent.run_stream('What is the capital of the UK?') as response: print(await response.get_output()) #> London ``` Args: user_prompt: User input to start/continue the conversation. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. 
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. """ # TODO: We need to deprecate this now that we have the `iter` method. # Before that, though, we should add an event for when we reach the final result of the stream. if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) yielded = False async with self.iter( user_prompt, output_type=output_type, message_history=message_history, model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, infer_name=False, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node node = first_node while True: if self.is_model_request_node(node): graph_ctx = agent_run.ctx async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage] async def stream_to_final( s: models.StreamedResponse, ) -> FinalResult[models.StreamedResponse] | None: output_schema = graph_ctx.deps.output_schema async for maybe_part_event in streamed_response: if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart) and isinstance( output_schema, _output.TextOutputSchema ): return FinalResult(s, None, None) elif isinstance(new_part, _messages.ToolCallPart) and isinstance( output_schema, _output.ToolOutputSchema ): # pragma: no branch for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) return None final_result_details = await stream_to_final(streamed_response) if final_result_details is not None: if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True messages = graph_ctx.state.message_history.copy() async def on_complete() -> None: """Called when the stream has completed. The model response will have been added to messages by now by `StreamedRunResult._marked_completed`. """ last_message = messages[-1] assert isinstance(last_message, _messages.ModelResponse) tool_calls = [ part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) ] parts: list[_messages.ModelRequestPart] = [] async for _event in _agent_graph.process_function_tools( tool_calls, final_result_details.tool_name, final_result_details.tool_call_id, graph_ctx, parts, ): pass # TODO: Should we do something here related to the retry count? # Maybe we should move the incrementing of the retry count to where we actually make a request? 
# if any(isinstance(part, _messages.RetryPromptPart) for part in parts): # ctx.state.increment_retries(ctx.deps.max_result_retries) if parts: messages.append(_messages.ModelRequest(parts)) yield StreamedRunResult( messages, graph_ctx.deps.new_message_index, graph_ctx.deps.usage_limits, streamed_response, graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), graph_ctx.deps.output_validators, final_result_details.tool_name, on_complete, ) break next_node = await agent_run.next(node) if not isinstance(next_node, _agent_graph.AgentNode): raise exceptions.AgentRunError( # pragma: no cover 'Should have produced a StreamedRunResult before getting here' ) node = cast(_agent_graph.AgentNode[Any, Any], next_node) if not yielded: raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover ```` #### override ```python override( *, deps: AgentDepsT | Unset = UNSET, model: Model | KnownModelName | str | Unset = UNSET ) -> Iterator[None] ``` Context manager to temporarily override agent dependencies and model. This is particularly useful when testing. You can find an example of this [here](../../testing/#overriding-model-via-pytest-fixtures). Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `deps` | `AgentDepsT | Unset` | The dependencies to use instead of the dependencies passed to the agent run. | `UNSET` | | `model` | `Model | KnownModelName | str | Unset` | The model to use instead of the model passed to the agent run. | `UNSET` | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @contextmanager def override( self, *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies and model. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) else: deps_token = None if _utils.is_set(model): model_token = self._override_model.set(_utils.Some(models.infer_model(model))) else: model_token = None try: yield finally: if deps_token is not None: self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) ``` #### instructions ```python instructions( func: Callable[[RunContext[AgentDepsT]], str], ) -> Callable[[RunContext[AgentDepsT]], str] ``` ```python instructions( func: Callable[ [RunContext[AgentDepsT]], Awaitable[str] ], ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]] ``` ```python instructions(func: Callable[[], str]) -> Callable[[], str] ``` ```python instructions( func: Callable[[], Awaitable[str]], ) -> Callable[[], Awaitable[str]] ``` ```python instructions() -> Callable[ [SystemPromptFunc[AgentDepsT]], SystemPromptFunc[AgentDepsT], ] ``` ```python instructions( func: SystemPromptFunc[AgentDepsT] | None = None, ) -> ( Callable[ [SystemPromptFunc[AgentDepsT]], SystemPromptFunc[AgentDepsT], ] | SystemPromptFunc[AgentDepsT] ) ``` Decorator to register an instructions function. Optionally takes RunContext as its only argument. Can decorate a sync or async functions. The decorator can be used bare (`agent.instructions`). 
Overloads for every possible signature of `instructions` are included so the decorator doesn't obscure the type of the function. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=str) @agent.instructions def simple_instructions() -> str: return 'foobar' @agent.instructions async def async_instructions(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def instructions( self, func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None, /, ) -> ( Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]] | _system_prompt.SystemPromptFunc[AgentDepsT] ): """Decorator to register an instructions function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. Can decorate a sync or async functions. The decorator can be used bare (`agent.instructions`). Overloads for every possible signature of `instructions` are included so the decorator doesn't obscure the type of the function. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=str) @agent.instructions def simple_instructions() -> str: return 'foobar' @agent.instructions async def async_instructions(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' ``` """ if func is None: def decorator( func_: _system_prompt.SystemPromptFunc[AgentDepsT], ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_)) return func_ return decorator else: self._instructions_functions.append(_system_prompt.SystemPromptRunner(func)) return func ```` #### system_prompt ```python system_prompt( func: Callable[[RunContext[AgentDepsT]], str], ) -> Callable[[RunContext[AgentDepsT]], str] ``` ```python system_prompt( func: Callable[ [RunContext[AgentDepsT]], Awaitable[str] ], ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]] ``` ```python system_prompt(func: Callable[[], str]) -> Callable[[], str] ``` ```python system_prompt( func: Callable[[], Awaitable[str]], ) -> Callable[[], Awaitable[str]] ``` ```python system_prompt(*, dynamic: bool = False) -> Callable[ [SystemPromptFunc[AgentDepsT]], SystemPromptFunc[AgentDepsT], ] ``` ```python system_prompt( func: SystemPromptFunc[AgentDepsT] | None = None, /, *, dynamic: bool = False, ) -> ( Callable[ [SystemPromptFunc[AgentDepsT]], SystemPromptFunc[AgentDepsT], ] | SystemPromptFunc[AgentDepsT] ) ``` Decorator to register a system prompt function. Optionally takes RunContext as its only argument. Can decorate a sync or async functions. The decorator can be used either bare (`agent.system_prompt`) or as a function call (`agent.system_prompt(...)`), see the examples below. Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. 
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `func` | `SystemPromptFunc[AgentDepsT] | None` | The function to decorate | `None` | | `dynamic` | `bool` | If True, the system prompt will be reevaluated even when messages_history is provided, see SystemPromptPart.dynamic_ref | `False` | Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=str) @agent.system_prompt def simple_system_prompt() -> str: return 'foobar' @agent.system_prompt(dynamic=True) async def async_system_prompt(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def system_prompt( self, func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None, /, *, dynamic: bool = False, ) -> ( Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]] | _system_prompt.SystemPromptFunc[AgentDepsT] ): """Decorator to register a system prompt function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. Can decorate a sync or async functions. The decorator can be used either bare (`agent.system_prompt`) or as a function call (`agent.system_prompt(...)`), see the examples below. Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. Args: func: The function to decorate dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided, see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref] Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=str) @agent.system_prompt def simple_system_prompt() -> str: return 'foobar' @agent.system_prompt(dynamic=True) async def async_system_prompt(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' ``` """ if func is None: def decorator( func_: _system_prompt.SystemPromptFunc[AgentDepsT], ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic) self._system_prompt_functions.append(runner) if dynamic: # pragma: lax no cover self._system_prompt_dynamic_functions[func_.__qualname__] = runner return func_ return decorator else: assert not dynamic, "dynamic can't be True in this case" self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic)) return func ```` #### output_validator ```python output_validator( func: Callable[ [RunContext[AgentDepsT], OutputDataT], OutputDataT ], ) -> Callable[ [RunContext[AgentDepsT], OutputDataT], OutputDataT ] ``` ```python output_validator( func: Callable[ [RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT], ], ) -> Callable[ [RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT], ] ``` ```python output_validator( func: Callable[[OutputDataT], OutputDataT], ) -> Callable[[OutputDataT], OutputDataT] ``` ```python output_validator( func: Callable[[OutputDataT], Awaitable[OutputDataT]], ) -> Callable[[OutputDataT], Awaitable[OutputDataT]] ``` ```python output_validator( func: OutputValidatorFunc[AgentDepsT, OutputDataT], ) -> OutputValidatorFunc[AgentDepsT, OutputDataT] ``` Decorator to register an output validator function. Optionally takes RunContext as its first argument. Can decorate a sync or async functions. 
Overloads for every possible signature of `output_validator` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. Example: ```python from pydantic_ai import Agent, ModelRetry, RunContext agent = Agent('test', deps_type=str) @agent.output_validator def output_validator_simple(data: str) -> str: if 'wrong' in data: raise ModelRetry('wrong response') return data @agent.output_validator async def output_validator_deps(ctx: RunContext[str], data: str) -> str: if ctx.deps in data: raise ModelRetry('wrong response') return data result = agent.run_sync('foobar', deps='spam') print(result.output) #> success (no tool calls) ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def output_validator( self, func: _output.OutputValidatorFunc[AgentDepsT, OutputDataT], / ) -> _output.OutputValidatorFunc[AgentDepsT, OutputDataT]: """Decorator to register an output validator function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. Can decorate a sync or async functions. Overloads for every possible signature of `output_validator` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. Example: ```python from pydantic_ai import Agent, ModelRetry, RunContext agent = Agent('test', deps_type=str) @agent.output_validator def output_validator_simple(data: str) -> str: if 'wrong' in data: raise ModelRetry('wrong response') return data @agent.output_validator async def output_validator_deps(ctx: RunContext[str], data: str) -> str: if ctx.deps in data: raise ModelRetry('wrong response') return data result = agent.run_sync('foobar', deps='spam') print(result.output) #> success (no tool calls) ``` """ self._output_validators.append(_output.OutputValidator[AgentDepsT, Any](func)) return func ```` #### tool ```python tool( func: ToolFuncContext[AgentDepsT, ToolParams], ) -> ToolFuncContext[AgentDepsT, ToolParams] ``` ```python tool( *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = "auto", require_parameter_descriptions: bool = False, schema_generator: type[ GenerateJsonSchema ] = GenerateToolJsonSchema, strict: bool | None = None ) -> Callable[ [ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams], ] ``` ```python tool( func: ( ToolFuncContext[AgentDepsT, ToolParams] | None ) = None, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = "auto", require_parameter_descriptions: bool = False, schema_generator: type[ GenerateJsonSchema ] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Any ``` Decorator to register a tool function which takes RunContext as its first argument. Can decorate a sync or async functions. The docstring is inspected to extract both the tool description and description of each parameter, [learn more](../../tools/#function-tools-and-schema). We can't add overloads for every possible signature of tool, since the return type is a recursive union so the signature of functions decorated with `@agent.tool` is obscured. 
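Because the tool's JSON schema is built from its docstring, a tool can carry its own parameter documentation. The snippet below is an illustrative sketch rather than library-provided material: the tool name, docstring wording, and the choice of `docstring_format='google'` are assumptions, while the `docstring_format` and `require_parameter_descriptions` keyword arguments are taken from the parameter table further down. The built-in example that follows shows the decorator's basic usage.

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=int)


# Hypothetical tool: the Google-style docstring supplies the tool description
# and the per-parameter description used in the generated JSON schema.
@agent.tool(docstring_format='google', require_parameter_descriptions=True)
def apply_offset(ctx: RunContext[int], amount: int) -> int:
    """Add the run's integer dependency to a number.

    Args:
        amount: The number the dependency offset is added to.
    """
    return ctx.deps + amount
```
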
Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=int) @agent.tool def foobar(ctx: RunContext[int], x: int) -> int: return ctx.deps + x @agent.tool(retries=2) async def spam(ctx: RunContext[str], y: float) -> float: return ctx.deps + y result = agent.run_sync('foobar', deps=1) print(result.output) #> {"foobar":1,"spam":1.0} ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `func` | `ToolFuncContext[AgentDepsT, ToolParams] | None` | The tool function to register. | `None` | | `name` | `str | None` | The name of the tool, defaults to the function name. | `None` | | `retries` | `int | None` | The number of retries to allow for this tool, defaults to the agent's default retries, which defaults to 1. | `None` | | `prepare` | `ToolPrepareFunc[AgentDepsT] | None` | custom method to prepare the tool definition for each step, return None to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See ToolPrepareFunc. | `None` | | `docstring_format` | `DocstringFormat` | The format of the docstring, see DocstringFormat. Defaults to 'auto', such that the format is inferred from the structure of the docstring. | `'auto'` | | `require_parameter_descriptions` | `bool` | If True, raise an error if a parameter description is missing. Defaults to False. | `False` | | `schema_generator` | `type[GenerateJsonSchema]` | The JSON schema generator class to use for this tool. Defaults to GenerateToolJsonSchema. | `GenerateToolJsonSchema` | | `strict` | `bool | None` | Whether to enforce JSON schema compliance (only affects OpenAI). See ToolDefinition for more info. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def tool( self, func: ToolFuncContext[AgentDepsT, ToolParams] | None = None, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Any: """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. Can decorate a sync or async functions. The docstring is inspected to extract both the tool description and description of each parameter, [learn more](../tools.md#function-tools-and-schema). We can't add overloads for every possible signature of tool, since the return type is a recursive union so the signature of functions decorated with `@agent.tool` is obscured. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test', deps_type=int) @agent.tool def foobar(ctx: RunContext[int], x: int) -> int: return ctx.deps + x @agent.tool(retries=2) async def spam(ctx: RunContext[str], y: float) -> float: return ctx.deps + y result = agent.run_sync('foobar', deps=1) print(result.output) #> {"foobar":1,"spam":1.0} ``` Args: func: The tool function to register. name: The name of the tool, defaults to the function name. retries: The number of retries to allow for this tool, defaults to the agent's default retries, which defaults to 1. prepare: custom method to prepare the tool definition for each step, return `None` to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. 
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ if func is None: def tool_decorator( func_: ToolFuncContext[AgentDepsT, ToolParams], ) -> ToolFuncContext[AgentDepsT, ToolParams]: # noinspection PyTypeChecker self._register_function( func_, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func_ return tool_decorator else: # noinspection PyTypeChecker self._register_function( func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func ```` #### tool_plain ```python tool_plain( func: ToolFuncPlain[ToolParams], ) -> ToolFuncPlain[ToolParams] ``` ```python tool_plain( *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = "auto", require_parameter_descriptions: bool = False, schema_generator: type[ GenerateJsonSchema ] = GenerateToolJsonSchema, strict: bool | None = None ) -> Callable[ [ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams] ] ``` ```python tool_plain( func: ToolFuncPlain[ToolParams] | None = None, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = "auto", require_parameter_descriptions: bool = False, schema_generator: type[ GenerateJsonSchema ] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Any ``` Decorator to register a tool function which DOES NOT take `RunContext` as an argument. Can decorate a sync or async functions. The docstring is inspected to extract both the tool description and description of each parameter, [learn more](../../tools/#function-tools-and-schema). We can't add overloads for every possible signature of tool, since the return type is a recursive union so the signature of functions decorated with `@agent.tool` is obscured. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test') @agent.tool def foobar(ctx: RunContext[int]) -> int: return 123 @agent.tool(retries=2) async def spam(ctx: RunContext[str]) -> float: return 3.14 result = agent.run_sync('foobar', deps=1) print(result.output) #> {"foobar":123,"spam":3.14} ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `func` | `ToolFuncPlain[ToolParams] | None` | The tool function to register. | `None` | | `name` | `str | None` | The name of the tool, defaults to the function name. | `None` | | `retries` | `int | None` | The number of retries to allow for this tool, defaults to the agent's default retries, which defaults to 1. | `None` | | `prepare` | `ToolPrepareFunc[AgentDepsT] | None` | custom method to prepare the tool definition for each step, return None to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See ToolPrepareFunc. | `None` | | `docstring_format` | `DocstringFormat` | The format of the docstring, see DocstringFormat. 
Defaults to 'auto', such that the format is inferred from the structure of the docstring. | `'auto'` | | `require_parameter_descriptions` | `bool` | If True, raise an error if a parameter description is missing. Defaults to False. | `False` | | `schema_generator` | `type[GenerateJsonSchema]` | The JSON schema generator class to use for this tool. Defaults to GenerateToolJsonSchema. | `GenerateToolJsonSchema` | | `strict` | `bool | None` | Whether to enforce JSON schema compliance (only affects OpenAI). See ToolDefinition for more info. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def tool_plain( self, func: ToolFuncPlain[ToolParams] | None = None, /, *, name: str | None = None, retries: int | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, ) -> Any: """Decorator to register a tool function which DOES NOT take `RunContext` as an argument. Can decorate a sync or async functions. The docstring is inspected to extract both the tool description and description of each parameter, [learn more](../tools.md#function-tools-and-schema). We can't add overloads for every possible signature of tool, since the return type is a recursive union so the signature of functions decorated with `@agent.tool` is obscured. Example: ```python from pydantic_ai import Agent, RunContext agent = Agent('test') @agent.tool def foobar(ctx: RunContext[int]) -> int: return 123 @agent.tool(retries=2) async def spam(ctx: RunContext[str]) -> float: return 3.14 result = agent.run_sync('foobar', deps=1) print(result.output) #> {"foobar":123,"spam":3.14} ``` Args: func: The tool function to register. name: The name of the tool, defaults to the function name. retries: The number of retries to allow for this tool, defaults to the agent's default retries, which defaults to 1. prepare: custom method to prepare the tool definition for each step, return `None` to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ if func is None: def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: # noinspection PyTypeChecker self._register_function( func_, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func_ return tool_decorator else: self._register_function( func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator, strict, ) return func ```` #### is_model_request_node ```python is_model_request_node( node: AgentNode[T, S] | End[FinalResult[S]], ) -> TypeIs[ModelRequestNode[T, S]] ``` Check if the node is a `ModelRequestNode`, narrowing the type if it is. 
This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @staticmethod def is_model_request_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: """Check if the node is a `ModelRequestNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, _agent_graph.ModelRequestNode) ``` #### is_call_tools_node ```python is_call_tools_node( node: AgentNode[T, S] | End[FinalResult[S]], ) -> TypeIs[CallToolsNode[T, S]] ``` Check if the node is a `CallToolsNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @staticmethod def is_call_tools_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: """Check if the node is a `CallToolsNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, _agent_graph.CallToolsNode) ``` #### is_user_prompt_node ```python is_user_prompt_node( node: AgentNode[T, S] | End[FinalResult[S]], ) -> TypeIs[UserPromptNode[T, S]] ``` Check if the node is a `UserPromptNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @staticmethod def is_user_prompt_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: """Check if the node is a `UserPromptNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, _agent_graph.UserPromptNode) ``` #### is_end_node ```python is_end_node( node: AgentNode[T, S] | End[FinalResult[S]], ) -> TypeIs[End[FinalResult[S]]] ``` Check if the node is a `End`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @staticmethod def is_end_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], ) -> TypeIs[End[result.FinalResult[S]]]: """Check if the node is a `End`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. """ return isinstance(node, End) ``` #### run_mcp_servers ```python run_mcp_servers( model: Model | KnownModelName | str | None = None, ) -> AsyncIterator[None] ``` Run MCPServerStdios so they can be used by the agent. Returns: a context manager to start and shutdown the servers. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @asynccontextmanager async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. Returns: a context manager to start and shutdown the servers. 
""" try: sampling_model: models.Model | None = self._get_model(model) except exceptions.UserError: # pragma: no cover sampling_model = None exit_stack = AsyncExitStack() try: for mcp_server in self._mcp_servers: if sampling_model is not None: # pragma: no branch mcp_server.sampling_model = sampling_model await exit_stack.enter_async_context(mcp_server) yield finally: await exit_stack.aclose() ``` #### to_a2a ```python to_a2a( *, storage: Storage | None = None, broker: Broker | None = None, name: str | None = None, url: str = "http://localhost:8000", version: str = "1.0.0", description: str | None = None, provider: Provider | None = None, skills: list[Skill] | None = None, debug: bool = False, routes: Sequence[Route] | None = None, middleware: Sequence[Middleware] | None = None, exception_handlers: ( dict[Any, ExceptionHandler] | None ) = None, lifespan: Lifespan[FastA2A] | None = None ) -> FastA2A ``` Convert the agent to a FastA2A application. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') app = agent.to_a2a() ``` The `app` is an ASGI application that can be used with any ASGI server. To run the application, you can use the following command: ```bash uvicorn app:app --host 0.0.0.0 --port 8000 ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def to_a2a( self, *, storage: Storage | None = None, broker: Broker | None = None, # Agent card name: str | None = None, url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, provider: Provider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, routes: Sequence[Route] | None = None, middleware: Sequence[Middleware] | None = None, exception_handlers: dict[Any, ExceptionHandler] | None = None, lifespan: Lifespan[FastA2A] | None = None, ) -> FastA2A: """Convert the agent to a FastA2A application. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') app = agent.to_a2a() ``` The `app` is an ASGI application that can be used with any ASGI server. To run the application, you can use the following command: ```bash uvicorn app:app --host 0.0.0.0 --port 8000 ``` """ from ._a2a import agent_to_a2a return agent_to_a2a( self, storage=storage, broker=broker, name=name, url=url, version=version, description=description, provider=provider, skills=skills, debug=debug, routes=routes, middleware=middleware, exception_handlers=exception_handlers, lifespan=lifespan, ) ```` #### to_cli ```python to_cli( deps: AgentDepsT = None, prog_name: str = "pydantic-ai" ) -> None ``` Run the agent in a CLI chat interface. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `deps` | `AgentDepsT` | The dependencies to pass to the agent. | `None` | | `prog_name` | `str` | The name of the program to use for the CLI. Defaults to 'pydantic-ai'. | `'pydantic-ai'` | Example: agent_to_cli.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') async def main(): await agent.to_cli() ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: """Run the agent in a CLI chat interface. Args: deps: The dependencies to pass to the agent. prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. 
Example: ```python {title="agent_to_cli.py" test="skip"} from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') async def main(): await agent.to_cli() ``` """ from rich.console import Console from pydantic_ai._cli import run_chat await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) ```` #### to_cli_sync ```python to_cli_sync( deps: AgentDepsT = None, prog_name: str = "pydantic-ai" ) -> None ``` Run the agent in a CLI chat interface with the non-async interface. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `deps` | `AgentDepsT` | The dependencies to pass to the agent. | `None` | | `prog_name` | `str` | The name of the program to use for the CLI. Defaults to 'pydantic-ai'. | `'pydantic-ai'` | agent_to_cli_sync.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') agent.to_cli_sync() agent.to_cli_sync(prog_name='assistant') ``` Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: """Run the agent in a CLI chat interface with the non-async interface. Args: deps: The dependencies to pass to the agent. prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. ```python {title="agent_to_cli_sync.py" test="skip"} from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') agent.to_cli_sync() agent.to_cli_sync(prog_name='assistant') ``` """ return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) ```` ### AgentRun Bases: `Generic[AgentDepsT, OutputDataT]` A stateful, async-iterable run of an Agent. You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`. Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an End is reached, the run finishes and result becomes available. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): nodes = [] # Iterate through the run, recording each node along the way: async with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: nodes.append(node) print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print(agent_run.result.output) #> Paris ``` You can also manually drive the iteration using the next method for more granular control. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, OutputDataT]): """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent]. You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`. 
Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result] becomes available. Example: ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o') async def main(): nodes = [] # Iterate through the run, recording each node along the way: async with agent.iter('What is the capital of France?') as agent_run: async for node in agent_run: nodes.append(node) print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57 ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print(agent_run.result.output) #> Paris ``` You can also manually drive the iteration using the [`next`][pydantic_ai.agent.AgentRun.next] method for more granular control. """ _graph_run: GraphRun[ _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[OutputDataT] ] @overload def _traceparent(self, *, required: Literal[False]) -> str | None: ... @overload def _traceparent(self) -> str: ... def _traceparent(self, *, required: bool = True) -> str | None: traceparent = self._graph_run._traceparent(required=False) # type: ignore[reportPrivateUsage] if traceparent is None and required: # pragma: no cover raise AttributeError('No span was created for this agent run') return traceparent @property def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: """The current context of the agent run.""" return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]( self._graph_run.state, self._graph_run.deps ) @property def next_node( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """The next node that will be run in the agent graph. This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ next_node = self._graph_run.next_node if isinstance(next_node, End): return next_node if _agent_graph.is_agent_node(next_node): return next_node raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover @property def result(self) -> AgentRunResult[OutputDataT] | None: """The final result of the run if it has ended, otherwise `None`. Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. 
""" graph_run_result = self._graph_run.result if graph_run_result is None: return None return AgentRunResult( graph_run_result.output.output, graph_run_result.output.tool_name, graph_run_result.state, self._graph_run.deps.new_message_index, self._traceparent(required=False), ) def __aiter__( self, ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]]: """Provide async-iteration over the nodes in the agent run.""" return self async def __anext__( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Advance to the next node automatically based on the last returned node.""" next_node = await self._graph_run.__anext__() if _agent_graph.is_agent_node(next_node): return next_node assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node async def next( self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT], ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Manually drive the agent run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End] node. Example: ```python from pydantic_ai import Agent from pydantic_graph import End agent = Agent('openai:gpt-4o') async def main(): async with agent.iter('What is the capital of France?') as agent_run: next_node = agent_run.next_node # start with the first node nodes = [next_node] while not isinstance(next_node, End): next_node = await agent_run.next(next_node) nodes.append(next_node) # Once `next_node` is an End, we've finished: print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57, ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print('Final result:', agent_run.result.output) #> Final result: Paris ``` Args: node: The node to run next in the graph. Returns: The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if the run has completed. """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. next_node = await self._graph_run.next(node) if _agent_graph.is_agent_node(next_node): return next_node assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node def usage(self) -> _usage.Usage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage def __repr__(self) -> str: # pragma: no cover result = self._graph_run.result result_repr = '' if result is None else repr(result.output) return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>' ```` #### ctx ```python ctx: GraphRunContext[ GraphAgentState, GraphAgentDeps[AgentDepsT, Any] ] ``` The current context of the agent run. 
#### next_node ```python next_node: ( AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]] ) ``` The next node that will be run in the agent graph. This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. #### result ```python result: AgentRunResult[OutputDataT] | None ``` The final result of the run if it has ended, otherwise `None`. Once the run returns an End node, `result` is populated with an AgentRunResult. #### __aiter__ ```python __aiter__() -> ( AsyncIterator[ AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]] ] ) ``` Provide async-iteration over the nodes in the agent run. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def __aiter__( self, ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]]: """Provide async-iteration over the nodes in the agent run.""" return self ``` #### __anext__ ```python __anext__() -> ( AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]] ) ``` Advance to the next node automatically based on the last returned node. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python async def __anext__( self, ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Advance to the next node automatically based on the last returned node.""" next_node = await self._graph_run.__anext__() if _agent_graph.is_agent_node(next_node): return next_node assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node ``` #### next ```python next( node: AgentNode[AgentDepsT, OutputDataT], ) -> ( AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]] ) ``` Manually drive the agent run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes under dynamic conditions. The agent run should be stopped when you return an End node. Example: ```python from pydantic_ai import Agent from pydantic_graph import End agent = Agent('openai:gpt-4o') async def main(): async with agent.iter('What is the capital of France?') as agent_run: next_node = agent_run.next_node # start with the first node nodes = [next_node] while not isinstance(next_node, End): next_node = await agent_run.next(next_node) nodes.append(next_node) # Once `next_node` is an End, we've finished: print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57, ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print('Final result:', agent_run.result.output) #> Final result: Paris ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `node` | `AgentNode[AgentDepsT, OutputDataT]` | The node to run next in the graph. 
| *required* | Returns: | Type | Description | | --- | --- | | `AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]` | The next node returned by the graph logic, or an End node if | | `AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]` | the run has completed. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ````python async def next( self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT], ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: """Manually drive the agent run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End] node. Example: ```python from pydantic_ai import Agent from pydantic_graph import End agent = Agent('openai:gpt-4o') async def main(): async with agent.iter('What is the capital of France?') as agent_run: next_node = agent_run.next_node # start with the first node nodes = [next_node] while not isinstance(next_node, End): next_node = await agent_run.next(next_node) nodes.append(next_node) # Once `next_node` is an End, we've finished: print(nodes) ''' [ UserPromptNode( user_prompt='What is the capital of France?', instructions=None, instructions_functions=[], system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, ), ModelRequestNode( request=ModelRequest( parts=[ UserPromptPart( content='What is the capital of France?', timestamp=datetime.datetime(...), ) ] ) ), CallToolsNode( model_response=ModelResponse( parts=[TextPart(content='Paris')], usage=Usage( requests=1, request_tokens=56, response_tokens=1, total_tokens=57, ), model_name='gpt-4o', timestamp=datetime.datetime(...), ) ), End(data=FinalResult(output='Paris')), ] ''' print('Final result:', agent_run.result.output) #> Final result: Paris ``` Args: node: The node to run next in the graph. Returns: The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if the run has completed. """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. next_node = await self._graph_run.next(node) if _agent_graph.is_agent_node(next_node): return next_node assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' return next_node ```` #### usage ```python usage() -> Usage ``` Get usage statistics for the run so far, including token usage, model requests, and so on. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def usage(self) -> _usage.Usage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" return self._graph_run.state.usage ``` ### AgentRunResult Bases: `Generic[OutputDataT]` The final result of an agent run. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python @dataclasses.dataclass class AgentRunResult(Generic[OutputDataT]): """The final result of an agent run.""" output: OutputDataT """The output data from the agent run.""" _output_tool_name: str | None = dataclasses.field(repr=False) _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False) _new_message_index: int = dataclasses.field(repr=False) _traceparent_value: str | None = dataclasses.field(repr=False) @overload def _traceparent(self, *, required: Literal[False]) -> str | None: ... 
@overload def _traceparent(self) -> str: ... def _traceparent(self, *, required: bool = True) -> str | None: if self._traceparent_value is None and required: # pragma: no cover raise AttributeError('No span was created for this agent run') return self._traceparent_value @property @deprecated('`result.data` is deprecated, use `result.output` instead.') def data(self) -> OutputDataT: return self.output def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: """Set return content for the output tool. Useful if you want to continue the conversation and want to set the response to the output tool call. """ if not self._output_tool_name: raise ValueError('Cannot set output tool return content when the return type is `str`.') messages = deepcopy(self._state.message_history) last_message = messages[-1] for part in last_message.parts: if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name: part.content = return_content return messages raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.') @overload def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... def all_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: """Return the history of _messages. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: List of messages. """ content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) if content is not None: return self._set_output_tool_return(content) else: return self._state.message_history @overload def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ... def all_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the messages. 
""" content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content)) @overload def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... def new_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: """Return new messages associated with this run. Messages from older runs are excluded. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: List of new messages. """ content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return self.all_messages(output_tool_return_content=content)[self._new_message_index :] @overload def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ... def new_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the new messages. """ content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content)) def usage(self) -> _usage.Usage: """Return the usage of the whole run.""" return self._state.usage ``` #### output ```python output: OutputDataT ``` The output data from the agent run. #### all_messages ```python all_messages( *, output_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python all_messages( *, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python all_messages( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` Return the history of \_messages. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. 
This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | Deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `list[ModelMessage]` | List of messages. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def all_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: """Return the history of _messages. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: List of messages. """ content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) if content is not None: return self._set_output_tool_return(content) else: return self._state.message_history ``` #### all_messages_json ```python all_messages_json( *, output_tool_return_content: str | None = None ) -> bytes ``` ```python all_messages_json( *, result_tool_return_content: str | None = None ) -> bytes ``` ```python all_messages_json( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes ``` Return all messages from all_messages as JSON bytes. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | Deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `bytes` | JSON bytes representing the messages. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def all_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the messages. 
""" content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content)) ``` #### new_messages ```python new_messages( *, output_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python new_messages( *, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python new_messages( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` Return new messages associated with this run. Messages from older runs are excluded. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | Deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `list[ModelMessage]` | List of new messages. | Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def new_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: """Return new messages associated with this run. Messages from older runs are excluded. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: List of new messages. """ content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return self.all_messages(output_tool_return_content=content)[self._new_message_index :] ``` #### new_messages_json ```python new_messages_json( *, output_tool_return_content: str | None = None ) -> bytes ``` ```python new_messages_json( *, result_tool_return_content: str | None = None ) -> bytes ``` ```python new_messages_json( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes ``` Return new messages from new_messages as JSON bytes. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | Deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `bytes` | JSON bytes representing the new messages. 
| Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def new_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: Deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the new messages. """ content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content)) ``` #### usage ```python usage() -> Usage ``` Return the usage of the whole run. Source code in `pydantic_ai_slim/pydantic_ai/agent.py` ```python def usage(self) -> _usage.Usage: """Return the usage of the whole run.""" return self._state.usage ``` ### EndStrategy ```python EndStrategy = EndStrategy ``` ### RunOutputDataT ```python RunOutputDataT = TypeVar('RunOutputDataT') ``` Type variable for the result data of a run where `output_type` was customized on the run call. ### capture_run_messages ```python capture_run_messages = capture_run_messages ``` ### InstrumentationSettings Options for instrumenting models and agents with OpenTelemetry. Used in: - `Agent(instrument=...)` - Agent.instrument_all() - InstrumentedModel See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python @dataclass(init=False) class InstrumentationSettings: """Options for instrumenting models and agents with OpenTelemetry. Used in: - `Agent(instrument=...)` - [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all] - [`InstrumentedModel`][pydantic_ai.models.instrumented.InstrumentedModel] See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. """ tracer: Tracer = field(repr=False) event_logger: EventLogger = field(repr=False) event_mode: Literal['attributes', 'logs'] = 'attributes' include_binary_content: bool = True def __init__( self, *, event_mode: Literal['attributes', 'logs'] = 'attributes', tracer_provider: TracerProvider | None = None, meter_provider: MeterProvider | None = None, event_logger_provider: EventLoggerProvider | None = None, include_binary_content: bool = True, include_content: bool = True, ): """Create instrumentation options. Args: event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes. If `'logs'`, events are emitted as OpenTelemetry log-based events. tracer_provider: The OpenTelemetry tracer provider to use. If not provided, the global tracer provider is used. Calling `logfire.configure()` sets the global tracer provider, so most users don't need this. meter_provider: The OpenTelemetry meter provider to use. If not provided, the global meter provider is used. Calling `logfire.configure()` sets the global meter provider, so most users don't need this. event_logger_provider: The OpenTelemetry event logger provider to use. If not provided, the global event logger provider is used. 
Calling `logfire.configure()` sets the global event logger provider, so most users don't need this. This is only used if `event_mode='logs'`. include_binary_content: Whether to include binary content in the instrumentation events. include_content: Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. """ from pydantic_ai import __version__ tracer_provider = tracer_provider or get_tracer_provider() meter_provider = meter_provider or get_meter_provider() event_logger_provider = event_logger_provider or get_event_logger_provider() scope_name = 'pydantic-ai' self.tracer = tracer_provider.get_tracer(scope_name, __version__) self.meter = meter_provider.get_meter(scope_name, __version__) self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__) self.event_mode = event_mode self.include_binary_content = include_binary_content self.include_content = include_content # As specified in the OpenTelemetry GenAI metrics spec: # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage tokens_histogram_kwargs = dict( name='gen_ai.client.token.usage', unit='{token}', description='Measures number of input and output tokens used', ) try: self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES, ) except TypeError: # pragma: lax no cover # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, # pyright: ignore ) def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: """Convert a list of model messages to OpenTelemetry events. Args: messages: The messages to convert. Returns: A list of OpenTelemetry events. """ events: list[Event] = [] instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage] if instructions is not None: events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'})) for message_index, message in enumerate(messages): message_events: list[Event] = [] if isinstance(message, ModelRequest): for part in message.parts: if hasattr(part, 'otel_event'): message_events.append(part.otel_event(self)) elif isinstance(message, ModelResponse): # pragma: no branch message_events = message.otel_events(self) for event in message_events: event.attributes = { 'gen_ai.message.index': message_index, **(event.attributes or {}), } events.extend(message_events) for event in events: event.body = InstrumentedModel.serialize_any(event.body) return events ``` #### __init__ ```python __init__( *, event_mode: Literal[ "attributes", "logs" ] = "attributes", tracer_provider: TracerProvider | None = None, meter_provider: MeterProvider | None = None, event_logger_provider: ( EventLoggerProvider | None ) = None, include_binary_content: bool = True, include_content: bool = True ) ``` Create instrumentation options. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `event_mode` | `Literal['attributes', 'logs']` | The mode for emitting events. If 'attributes', events are attached to the span as attributes. If 'logs', events are emitted as OpenTelemetry log-based events. | `'attributes'` | | `tracer_provider` | `TracerProvider | None` | The OpenTelemetry tracer provider to use. If not provided, the global tracer provider is used. 
Calling logfire.configure() sets the global tracer provider, so most users don't need this. | `None` | | `meter_provider` | `MeterProvider | None` | The OpenTelemetry meter provider to use. If not provided, the global meter provider is used. Calling logfire.configure() sets the global meter provider, so most users don't need this. | `None` | | `event_logger_provider` | `EventLoggerProvider | None` | The OpenTelemetry event logger provider to use. If not provided, the global event logger provider is used. Calling logfire.configure() sets the global event logger provider, so most users don't need this. This is only used if event_mode='logs'. | `None` | | `include_binary_content` | `bool` | Whether to include binary content in the instrumentation events. | `True` | | `include_content` | `bool` | Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. | `True` | Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python def __init__( self, *, event_mode: Literal['attributes', 'logs'] = 'attributes', tracer_provider: TracerProvider | None = None, meter_provider: MeterProvider | None = None, event_logger_provider: EventLoggerProvider | None = None, include_binary_content: bool = True, include_content: bool = True, ): """Create instrumentation options. Args: event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes. If `'logs'`, events are emitted as OpenTelemetry log-based events. tracer_provider: The OpenTelemetry tracer provider to use. If not provided, the global tracer provider is used. Calling `logfire.configure()` sets the global tracer provider, so most users don't need this. meter_provider: The OpenTelemetry meter provider to use. If not provided, the global meter provider is used. Calling `logfire.configure()` sets the global meter provider, so most users don't need this. event_logger_provider: The OpenTelemetry event logger provider to use. If not provided, the global event logger provider is used. Calling `logfire.configure()` sets the global event logger provider, so most users don't need this. This is only used if `event_mode='logs'`. include_binary_content: Whether to include binary content in the instrumentation events. include_content: Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. 
""" from pydantic_ai import __version__ tracer_provider = tracer_provider or get_tracer_provider() meter_provider = meter_provider or get_meter_provider() event_logger_provider = event_logger_provider or get_event_logger_provider() scope_name = 'pydantic-ai' self.tracer = tracer_provider.get_tracer(scope_name, __version__) self.meter = meter_provider.get_meter(scope_name, __version__) self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__) self.event_mode = event_mode self.include_binary_content = include_binary_content self.include_content = include_content # As specified in the OpenTelemetry GenAI metrics spec: # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage tokens_histogram_kwargs = dict( name='gen_ai.client.token.usage', unit='{token}', description='Measures number of input and output tokens used', ) try: self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES, ) except TypeError: # pragma: lax no cover # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, # pyright: ignore ) ``` #### messages_to_otel_events ```python messages_to_otel_events( messages: list[ModelMessage], ) -> list[Event] ``` Convert a list of model messages to OpenTelemetry events. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `messages` | `list[ModelMessage]` | The messages to convert. | *required* | Returns: | Type | Description | | --- | --- | | `list[Event]` | A list of OpenTelemetry events. | Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: """Convert a list of model messages to OpenTelemetry events. Args: messages: The messages to convert. Returns: A list of OpenTelemetry events. """ events: list[Event] = [] instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage] if instructions is not None: events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'})) for message_index, message in enumerate(messages): message_events: list[Event] = [] if isinstance(message, ModelRequest): for part in message.parts: if hasattr(part, 'otel_event'): message_events.append(part.otel_event(self)) elif isinstance(message, ModelResponse): # pragma: no branch message_events = message.otel_events(self) for event in message_events: event.attributes = { 'gen_ai.message.index': message_index, **(event.attributes or {}), } events.extend(message_events) for event in events: event.body = InstrumentedModel.serialize_any(event.body) return events ``` # `pydantic_ai.common_tools` ### duckduckgo_search_tool ```python duckduckgo_search_tool( duckduckgo_client: DDGS | None = None, max_results: int | None = None, ) ``` Creates a DuckDuckGo search tool. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `duckduckgo_client` | `DDGS | None` | The DuckDuckGo search client. | `None` | | `max_results` | `int | None` | The maximum number of results. If None, returns results only from the first response. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py` ```python def duckduckgo_search_tool(duckduckgo_client: DDGS | None = None, max_results: int | None = None): """Creates a DuckDuckGo search tool. 
Args: duckduckgo_client: The DuckDuckGo search client. max_results: The maximum number of results. If None, returns results only from the first response. """ return Tool( DuckDuckGoSearchTool(client=duckduckgo_client or DDGS(), max_results=max_results).__call__, name='duckduckgo_search', description='Searches DuckDuckGo for the given query and returns the results.', ) ``` ### tavily_search_tool ```python tavily_search_tool(api_key: str) ``` Creates a Tavily search tool. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `api_key` | `str` | The Tavily API key. You can get one by signing up at https://app.tavily.com/home. | *required* | Source code in `pydantic_ai_slim/pydantic_ai/common_tools/tavily.py` ```python def tavily_search_tool(api_key: str): """Creates a Tavily search tool. Args: api_key: The Tavily API key. You can get one by signing up at [https://app.tavily.com/home](https://app.tavily.com/home). """ return Tool( TavilySearchTool(client=AsyncTavilyClient(api_key)).__call__, name='tavily_search', description='Searches Tavily for the given query and returns the results.', ) ``` # `pydantic_ai.direct` Methods for making imperative requests to language models with minimal abstraction. These methods allow you to make requests to LLMs where the only abstraction is input and output schema translation so you can use all models with the same API. These methods are thin wrappers around Model implementations. ### model_request ```python model_request( model: Model | KnownModelName | str, messages: list[ModelMessage], *, model_settings: ModelSettings | None = None, model_request_parameters: ( ModelRequestParameters | None ) = None, instrument: InstrumentationSettings | bool | None = None ) -> ModelResponse ``` Make a non-streamed request to a model. model_request_example.py ```py from pydantic_ai.direct import model_request from pydantic_ai.messages import ModelRequest async def main(): model_response = await model_request( 'anthropic:claude-3-5-haiku-latest', [ModelRequest.user_text_prompt('What is the capital of France?')] # (1)! ) print(model_response) ''' ModelResponse( parts=[TextPart(content='Paris')], usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) ''' ``` 1. See ModelRequest.user_text_prompt for details. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model` | `Model | KnownModelName | str` | The model to make a request to. We allow str here since the actual list of allowed models changes frequently. | *required* | | `messages` | `list[ModelMessage]` | Messages to send to the model | *required* | | `model_settings` | `ModelSettings | None` | optional model settings | `None` | | `model_request_parameters` | `ModelRequestParameters | None` | optional model request parameters | `None` | | `instrument` | `InstrumentationSettings | bool | None` | Whether to instrument the request with OpenTelemetry/Logfire, if None the value from logfire.instrument_pydantic_ai is used. | `None` | Returns: | Type | Description | | --- | --- | | `ModelResponse` | The model response and token usage associated with the request. 
| Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ````python async def model_request( model: models.Model | models.KnownModelName | str, messages: list[messages.ModelMessage], *, model_settings: settings.ModelSettings | None = None, model_request_parameters: models.ModelRequestParameters | None = None, instrument: instrumented_models.InstrumentationSettings | bool | None = None, ) -> messages.ModelResponse: """Make a non-streamed request to a model. ```py title="model_request_example.py" from pydantic_ai.direct import model_request from pydantic_ai.messages import ModelRequest async def main(): model_response = await model_request( 'anthropic:claude-3-5-haiku-latest', [ModelRequest.user_text_prompt('What is the capital of France?')] # (1)! ) print(model_response) ''' ModelResponse( parts=[TextPart(content='Paris')], usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) ''' ``` 1. See [`ModelRequest.user_text_prompt`][pydantic_ai.messages.ModelRequest.user_text_prompt] for details. Args: model: The model to make a request to. We allow `str` here since the actual list of allowed models changes frequently. messages: Messages to send to the model model_settings: optional model settings model_request_parameters: optional model request parameters instrument: Whether to instrument the request with OpenTelemetry/Logfire, if `None` the value from [`logfire.instrument_pydantic_ai`][logfire.Logfire.instrument_pydantic_ai] is used. Returns: The model response and token usage associated with the request. """ model_instance = _prepare_model(model, instrument) return await model_instance.request( messages, model_settings, model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), ) ```` ### model_request_sync ```python model_request_sync( model: Model | KnownModelName | str, messages: list[ModelMessage], *, model_settings: ModelSettings | None = None, model_request_parameters: ( ModelRequestParameters | None ) = None, instrument: InstrumentationSettings | bool | None = None ) -> ModelResponse ``` Make a Synchronous, non-streamed request to a model. This is a convenience method that wraps model_request with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. model_request_sync_example.py ```py from pydantic_ai.direct import model_request_sync from pydantic_ai.messages import ModelRequest model_response = model_request_sync( 'anthropic:claude-3-5-haiku-latest', [ModelRequest.user_text_prompt('What is the capital of France?')] # (1)! ) print(model_response) ''' ModelResponse( parts=[TextPart(content='Paris')], usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) ''' ``` 1. See ModelRequest.user_text_prompt for details. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model` | `Model | KnownModelName | str` | The model to make a request to. We allow str here since the actual list of allowed models changes frequently. 
| *required* | | `messages` | `list[ModelMessage]` | Messages to send to the model | *required* | | `model_settings` | `ModelSettings | None` | optional model settings | `None` | | `model_request_parameters` | `ModelRequestParameters | None` | optional model request parameters | `None` | | `instrument` | `InstrumentationSettings | bool | None` | Whether to instrument the request with OpenTelemetry/Logfire, if None the value from logfire.instrument_pydantic_ai is used. | `None` | Returns: | Type | Description | | --- | --- | | `ModelResponse` | The model response and token usage associated with the request. | Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ````python def model_request_sync( model: models.Model | models.KnownModelName | str, messages: list[messages.ModelMessage], *, model_settings: settings.ModelSettings | None = None, model_request_parameters: models.ModelRequestParameters | None = None, instrument: instrumented_models.InstrumentationSettings | bool | None = None, ) -> messages.ModelResponse: """Make a Synchronous, non-streamed request to a model. This is a convenience method that wraps [`model_request`][pydantic_ai.direct.model_request] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. ```py title="model_request_sync_example.py" from pydantic_ai.direct import model_request_sync from pydantic_ai.messages import ModelRequest model_response = model_request_sync( 'anthropic:claude-3-5-haiku-latest', [ModelRequest.user_text_prompt('What is the capital of France?')] # (1)! ) print(model_response) ''' ModelResponse( parts=[TextPart(content='Paris')], usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), ) ''' ``` 1. See [`ModelRequest.user_text_prompt`][pydantic_ai.messages.ModelRequest.user_text_prompt] for details. Args: model: The model to make a request to. We allow `str` here since the actual list of allowed models changes frequently. messages: Messages to send to the model model_settings: optional model settings model_request_parameters: optional model request parameters instrument: Whether to instrument the request with OpenTelemetry/Logfire, if `None` the value from [`logfire.instrument_pydantic_ai`][logfire.Logfire.instrument_pydantic_ai] is used. Returns: The model response and token usage associated with the request. """ return _get_event_loop().run_until_complete( model_request( model, messages, model_settings=model_settings, model_request_parameters=model_request_parameters, instrument=instrument, ) ) ```` ### model_request_stream ```python model_request_stream( model: Model | KnownModelName | str, messages: list[ModelMessage], *, model_settings: ModelSettings | None = None, model_request_parameters: ( ModelRequestParameters | None ) = None, instrument: InstrumentationSettings | bool | None = None ) -> AbstractAsyncContextManager[StreamedResponse] ``` Make a streamed async request to a model. model_request_stream_example.py ```py from pydantic_ai.direct import model_request_stream from pydantic_ai.messages import ModelRequest async def main(): messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')] # (1)! 
async with model_request_stream('openai:gpt-4.1-mini', messages) as stream: chunks = [] async for chunk in stream: chunks.append(chunk) print(chunks) ''' [ PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')), PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='a German-born theoretical ') ), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='physicist.')), ] ''' ``` 1. See ModelRequest.user_text_prompt for details. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model` | `Model | KnownModelName | str` | The model to make a request to. We allow str here since the actual list of allowed models changes frequently. | *required* | | `messages` | `list[ModelMessage]` | Messages to send to the model | *required* | | `model_settings` | `ModelSettings | None` | optional model settings | `None` | | `model_request_parameters` | `ModelRequestParameters | None` | optional model request parameters | `None` | | `instrument` | `InstrumentationSettings | bool | None` | Whether to instrument the request with OpenTelemetry/Logfire, if None the value from logfire.instrument_pydantic_ai is used. | `None` | Returns: | Type | Description | | --- | --- | | `AbstractAsyncContextManager[StreamedResponse]` | A stream response async context manager. | Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ````python def model_request_stream( model: models.Model | models.KnownModelName | str, messages: list[messages.ModelMessage], *, model_settings: settings.ModelSettings | None = None, model_request_parameters: models.ModelRequestParameters | None = None, instrument: instrumented_models.InstrumentationSettings | bool | None = None, ) -> AbstractAsyncContextManager[models.StreamedResponse]: """Make a streamed async request to a model. ```py {title="model_request_stream_example.py"} from pydantic_ai.direct import model_request_stream from pydantic_ai.messages import ModelRequest async def main(): messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')] # (1)! async with model_request_stream('openai:gpt-4.1-mini', messages) as stream: chunks = [] async for chunk in stream: chunks.append(chunk) print(chunks) ''' [ PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')), PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='a German-born theoretical ') ), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='physicist.')), ] ''' ``` 1. See [`ModelRequest.user_text_prompt`][pydantic_ai.messages.ModelRequest.user_text_prompt] for details. Args: model: The model to make a request to. We allow `str` here since the actual list of allowed models changes frequently. messages: Messages to send to the model model_settings: optional model settings model_request_parameters: optional model request parameters instrument: Whether to instrument the request with OpenTelemetry/Logfire, if `None` the value from [`logfire.instrument_pydantic_ai`][logfire.Logfire.instrument_pydantic_ai] is used. Returns: A [stream response][pydantic_ai.models.StreamedResponse] async context manager. 
""" model_instance = _prepare_model(model, instrument) return model_instance.request_stream( messages, model_settings, model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), ) ```` ### model_request_stream_sync ```python model_request_stream_sync( model: Model | KnownModelName | str, messages: list[ModelMessage], *, model_settings: ModelSettings | None = None, model_request_parameters: ( ModelRequestParameters | None ) = None, instrument: InstrumentationSettings | bool | None = None ) -> StreamedResponseSync ``` Make a streamed synchronous request to a model. This is the synchronous version of model_request_stream. It uses threading to run the asynchronous stream in the background while providing a synchronous iterator interface. model_request_stream_sync_example.py ```py from pydantic_ai.direct import model_request_stream_sync from pydantic_ai.messages import ModelRequest messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')] with model_request_stream_sync('openai:gpt-4.1-mini', messages) as stream: chunks = [] for chunk in stream: chunks.append(chunk) print(chunks) ''' [ PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')), PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='a German-born theoretical ') ), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='physicist.')), ] ''' ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model` | `Model | KnownModelName | str` | The model to make a request to. We allow str here since the actual list of allowed models changes frequently. | *required* | | `messages` | `list[ModelMessage]` | Messages to send to the model | *required* | | `model_settings` | `ModelSettings | None` | optional model settings | `None` | | `model_request_parameters` | `ModelRequestParameters | None` | optional model request parameters | `None` | | `instrument` | `InstrumentationSettings | bool | None` | Whether to instrument the request with OpenTelemetry/Logfire, if None the value from logfire.instrument_pydantic_ai is used. | `None` | Returns: | Type | Description | | --- | --- | | `StreamedResponseSync` | A sync stream response context manager. | Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ````python def model_request_stream_sync( model: models.Model | models.KnownModelName | str, messages: list[messages.ModelMessage], *, model_settings: settings.ModelSettings | None = None, model_request_parameters: models.ModelRequestParameters | None = None, instrument: instrumented_models.InstrumentationSettings | bool | None = None, ) -> StreamedResponseSync: """Make a streamed synchronous request to a model. This is the synchronous version of [`model_request_stream`][pydantic_ai.direct.model_request_stream]. It uses threading to run the asynchronous stream in the background while providing a synchronous iterator interface. 
```py {title="model_request_stream_sync_example.py"} from pydantic_ai.direct import model_request_stream_sync from pydantic_ai.messages import ModelRequest messages = [ModelRequest.user_text_prompt('Who was Albert Einstein?')] with model_request_stream_sync('openai:gpt-4.1-mini', messages) as stream: chunks = [] for chunk in stream: chunks.append(chunk) print(chunks) ''' [ PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')), PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='a German-born theoretical ') ), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='physicist.')), ] ''' ``` Args: model: The model to make a request to. We allow `str` here since the actual list of allowed models changes frequently. messages: Messages to send to the model model_settings: optional model settings model_request_parameters: optional model request parameters instrument: Whether to instrument the request with OpenTelemetry/Logfire, if `None` the value from [`logfire.instrument_pydantic_ai`][logfire.Logfire.instrument_pydantic_ai] is used. Returns: A [sync stream response][pydantic_ai.direct.StreamedResponseSync] context manager. """ async_stream_cm = model_request_stream( model=model, messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters, instrument=instrument, ) return StreamedResponseSync(async_stream_cm) ```` ### StreamedResponseSync Synchronous wrapper to async streaming responses by running the async producer in a background thread and providing a synchronous iterator. This class must be used as a context manager with the `with` statement. Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ```python @dataclass class StreamedResponseSync: """Synchronous wrapper to async streaming responses by running the async producer in a background thread and providing a synchronous iterator. This class must be used as a context manager with the `with` statement. """ _async_stream_cm: AbstractAsyncContextManager[StreamedResponse] _queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field( default_factory=queue.Queue, init=False ) _thread: threading.Thread | None = field(default=None, init=False) _stream_response: StreamedResponse | None = field(default=None, init=False) _exception: Exception | None = field(default=None, init=False) _context_entered: bool = field(default=False, init=False) _stream_ready: threading.Event = field(default_factory=threading.Event, init=False) def __enter__(self) -> StreamedResponseSync: self._context_entered = True self._start_producer() return self def __exit__( self, _exc_type: type[BaseException] | None, _exc_val: BaseException | None, _exc_tb: TracebackType | None, ) -> None: self._cleanup() def __iter__(self) -> Iterator[messages.ModelResponseStreamEvent]: """Stream the response as an iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" self._check_context_manager_usage() while True: item = self._queue.get() if item is None: # End of stream break elif isinstance(item, Exception): raise item else: yield item def __repr__(self) -> str: if self._stream_response: return repr(self._stream_response) else: return f'{self.__class__.__name__}(context_entered={self._context_entered})' __str__ = __repr__ def _check_context_manager_usage(self) -> None: if not self._context_entered: raise RuntimeError( 'StreamedResponseSync must be used as a context manager. ' 'Use: `with model_request_stream_sync(...) 
as stream:`' ) def _ensure_stream_ready(self) -> StreamedResponse: self._check_context_manager_usage() if self._stream_response is None: # Wait for the background thread to signal that the stream is ready if not self._stream_ready.wait(timeout=STREAM_INITIALIZATION_TIMEOUT): raise RuntimeError('Stream failed to initialize within timeout') if self._stream_response is None: # pragma: no cover raise RuntimeError('Stream failed to initialize') return self._stream_response def _start_producer(self): self._thread = threading.Thread(target=self._async_producer, daemon=True) self._thread.start() def _async_producer(self): async def _consume_async_stream(): try: async with self._async_stream_cm as stream: self._stream_response = stream # Signal that the stream is ready self._stream_ready.set() async for event in stream: self._queue.put(event) except Exception as e: # Signal ready even on error so waiting threads don't hang self._stream_ready.set() self._queue.put(e) finally: self._queue.put(None) # Signal end _get_event_loop().run_until_complete(_consume_async_stream()) def _cleanup(self): if self._thread and self._thread.is_alive(): self._thread.join() def get(self) -> messages.ModelResponse: """Build a ModelResponse from the data received from the stream so far.""" return self._ensure_stream_ready().get() def usage(self) -> Usage: """Get the usage of the response so far.""" return self._ensure_stream_ready().usage() @property def model_name(self) -> str: """Get the model name of the response.""" return self._ensure_stream_ready().model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._ensure_stream_ready().timestamp ``` #### __iter__ ```python __iter__() -> Iterator[ModelResponseStreamEvent] ``` Stream the response as an iterable of ModelResponseStreamEvents. Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ```python def __iter__(self) -> Iterator[messages.ModelResponseStreamEvent]: """Stream the response as an iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" self._check_context_manager_usage() while True: item = self._queue.get() if item is None: # End of stream break elif isinstance(item, Exception): raise item else: yield item ``` #### get ```python get() -> ModelResponse ``` Build a ModelResponse from the data received from the stream so far. Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ```python def get(self) -> messages.ModelResponse: """Build a ModelResponse from the data received from the stream so far.""" return self._ensure_stream_ready().get() ``` #### usage ```python usage() -> Usage ``` Get the usage of the response so far. Source code in `pydantic_ai_slim/pydantic_ai/direct.py` ```python def usage(self) -> Usage: """Get the usage of the response so far.""" return self._ensure_stream_ready().usage() ``` #### model_name ```python model_name: str ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. # `pydantic_ai.exceptions` ### ModelRetry Bases: `Exception` Exception raised when a tool function should be retried. The agent will return the message to the model and ask it to try calling the function/tool again. Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class ModelRetry(Exception): """Exception raised when a tool function should be retried. The agent will return the message to the model and ask it to try calling the function/tool again. 
""" message: str """The message to return to the model.""" def __init__(self, message: str): self.message = message super().__init__(message) ``` #### message ```python message: str = message ``` The message to return to the model. ### UserError Bases: `RuntimeError` Error caused by a usage mistake by the application developer — You! Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class UserError(RuntimeError): """Error caused by a usage mistake by the application developer — You!""" message: str """Description of the mistake.""" def __init__(self, message: str): self.message = message super().__init__(message) ``` #### message ```python message: str = message ``` Description of the mistake. ### AgentRunError Bases: `RuntimeError` Base class for errors occurring during an agent run. Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class AgentRunError(RuntimeError): """Base class for errors occurring during an agent run.""" message: str """The error message.""" def __init__(self, message: str): self.message = message super().__init__(message) def __str__(self) -> str: return self.message ``` #### message ```python message: str = message ``` The error message. ### UsageLimitExceeded Bases: `AgentRunError` Error raised when a Model's usage exceeds the specified limits. Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class UsageLimitExceeded(AgentRunError): """Error raised when a Model's usage exceeds the specified limits.""" ``` ### UnexpectedModelBehavior Bases: `AgentRunError` Error caused by unexpected Model behavior, e.g. an unexpected response code. Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class UnexpectedModelBehavior(AgentRunError): """Error caused by unexpected Model behavior, e.g. an unexpected response code.""" message: str """Description of the unexpected behavior.""" body: str | None """The body of the response, if available.""" def __init__(self, message: str, body: str | None = None): self.message = message if body is None: self.body: str | None = None else: try: self.body = json.dumps(json.loads(body), indent=2) except ValueError: self.body = body super().__init__(message) def __str__(self) -> str: if self.body: return f'{self.message}, body:\n{self.body}' else: return self.message ``` #### message ```python message: str = message ``` Description of the unexpected behavior. #### body ```python body: str | None = dumps(loads(body), indent=2) ``` The body of the response, if available. ### ModelHTTPError Bases: `AgentRunError` Raised when an model provider response has a status code of 4xx or 5xx. Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class ModelHTTPError(AgentRunError): """Raised when an model provider response has a status code of 4xx or 5xx.""" status_code: int """The HTTP status code returned by the API.""" model_name: str """The name of the model associated with the error.""" body: object | None """The body of the response, if available.""" message: str """The error message with the status code and response body, if available.""" def __init__(self, status_code: int, model_name: str, body: object | None = None): self.status_code = status_code self.model_name = model_name self.body = body message = f'status_code: {status_code}, model_name: {model_name}, body: {body}' super().__init__(message) ``` #### message ```python message: str ``` The error message with the status code and response body, if available. 
#### status_code ```python status_code: int = status_code ``` The HTTP status code returned by the API. #### model_name ```python model_name: str = model_name ``` The name of the model associated with the error. #### body ```python body: object | None = body ``` The body of the response, if available. ### FallbackExceptionGroup Bases: `ExceptionGroup` A group of exceptions that can be raised when all fallback models fail. Source code in `pydantic_ai_slim/pydantic_ai/exceptions.py` ```python class FallbackExceptionGroup(ExceptionGroup): """A group of exceptions that can be raised when all fallback models fail.""" ``` # `fasta2a` ### FastA2A Bases: `Starlette` The main class for the FastA2A library. Source code in `fasta2a/fasta2a/applications.py` ```python class FastA2A(Starlette): """The main class for the FastA2A library.""" def __init__( self, *, storage: Storage, broker: Broker, # Agent card name: str | None = None, url: str = 'http://localhost:8000', version: str = '1.0.0', description: str | None = None, provider: Provider | None = None, skills: list[Skill] | None = None, # Starlette debug: bool = False, routes: Sequence[Route] | None = None, middleware: Sequence[Middleware] | None = None, exception_handlers: dict[Any, ExceptionHandler] | None = None, lifespan: Lifespan[FastA2A] | None = None, ): if lifespan is None: lifespan = _default_lifespan super().__init__( debug=debug, routes=routes, middleware=middleware, exception_handlers=exception_handlers, lifespan=lifespan, ) self.name = name or 'Agent' self.url = url self.version = version self.description = description self.provider = provider self.skills = skills or [] # NOTE: For now, I don't think there's any reason to support any other input/output modes. self.default_input_modes = ['application/json'] self.default_output_modes = ['application/json'] self.task_manager = TaskManager(broker=broker, storage=storage) # Setup self._agent_card_json_schema: bytes | None = None self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS']) self.router.add_route('/', self._agent_run_endpoint, methods=['POST']) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope['type'] == 'http' and not self.task_manager.is_running: raise RuntimeError('TaskManager was not properly initialized.') await super().__call__(scope, receive, send) async def _agent_card_endpoint(self, request: Request) -> Response: if self._agent_card_json_schema is None: agent_card = AgentCard( name=self.name, url=self.url, version=self.version, skills=self.skills, default_input_modes=self.default_input_modes, default_output_modes=self.default_output_modes, capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False), authentication=Authentication(schemes=[]), ) if self.description is not None: agent_card['description'] = self.description if self.provider is not None: agent_card['provider'] = self.provider self._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True) return Response(content=self._agent_card_json_schema, media_type='application/json') async def _agent_run_endpoint(self, request: Request) -> Response: """This is the main endpoint for the A2A server. Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. Never a "completed" on the first message. 2. 
There are three possible ends for the task: 2.1. The task was "completed" successfully. 2.2. The task was "canceled". 2.3. The task "failed". 3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`. """ data = await request.body() a2a_request = a2a_request_ta.validate_json(data) if a2a_request['method'] == 'tasks/send': jsonrpc_response = await self.task_manager.send_task(a2a_request) elif a2a_request['method'] == 'tasks/get': jsonrpc_response = await self.task_manager.get_task(a2a_request) elif a2a_request['method'] == 'tasks/cancel': jsonrpc_response = await self.task_manager.cancel_task(a2a_request) else: raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') return Response( content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json' ) ``` ### Broker Bases: `ABC` The broker class is in charge of scheduling the tasks. The HTTP server uses the broker to schedule tasks. The simple implementation is the `InMemoryBroker`, which is the broker that runs the tasks in the same process as the HTTP server. That said, this class can be extended to support remote workers. Source code in `fasta2a/fasta2a/broker.py` ```python @dataclass class Broker(ABC): """The broker class is in charge of scheduling the tasks. The HTTP server uses the broker to schedule tasks. The simple implementation is the `InMemoryBroker`, which is the broker that runs the tasks in the same process as the HTTP server. That said, this class can be extended to support remote workers. """ @abstractmethod async def run_task(self, params: TaskSendParams) -> None: """Send a task to be executed by the worker.""" raise NotImplementedError('send_run_task is not implemented yet.') @abstractmethod async def cancel_task(self, params: TaskIdParams) -> None: """Cancel a task.""" raise NotImplementedError('send_cancel_task is not implemented yet.') @abstractmethod async def __aenter__(self) -> Self: ... @abstractmethod async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): ... @abstractmethod def receive_task_operations(self) -> AsyncIterator[TaskOperation]: """Receive task operations from the broker. On a multi-worker setup, the broker will need to round-robin the task operations between the workers. """ ``` #### run_task ```python run_task(params: TaskSendParams) -> None ``` Send a task to be executed by the worker. Source code in `fasta2a/fasta2a/broker.py` ```python @abstractmethod async def run_task(self, params: TaskSendParams) -> None: """Send a task to be executed by the worker.""" raise NotImplementedError('send_run_task is not implemented yet.') ``` #### cancel_task ```python cancel_task(params: TaskIdParams) -> None ``` Cancel a task. Source code in `fasta2a/fasta2a/broker.py` ```python @abstractmethod async def cancel_task(self, params: TaskIdParams) -> None: """Cancel a task.""" raise NotImplementedError('send_cancel_task is not implemented yet.') ``` #### receive_task_operations ```python receive_task_operations() -> AsyncIterator[TaskOperation] ``` Receive task operations from the broker. On a multi-worker setup, the broker will need to round-robin the task operations between the workers. Source code in `fasta2a/fasta2a/broker.py` ```python @abstractmethod def receive_task_operations(self) -> AsyncIterator[TaskOperation]: """Receive task operations from the broker. On a multi-worker setup, the broker will need to round-robin the task operations between the workers. 
""" ``` ### Skill Bases: `TypedDict` Skills are a unit of capability that an agent can perform. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class Skill(TypedDict): """Skills are a unit of capability that an agent can perform.""" id: str """A unique identifier for the skill.""" name: str """Human readable name of the skill.""" description: str """A human-readable description of the skill. It will be used by the client or a human as a hint to understand the skill. """ tags: list[str] """Set of tag-words describing classes of capabilities for this specific skill. Examples: "cooking", "customer support", "billing". """ examples: NotRequired[list[str]] """The set of example scenarios that the skill can perform. Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") """ input_modes: list[str] """Supported mime types for input data.""" output_modes: list[str] """Supported mime types for output data.""" ``` #### id ```python id: str ``` A unique identifier for the skill. #### name ```python name: str ``` Human readable name of the skill. #### description ```python description: str ``` A human-readable description of the skill. It will be used by the client or a human as a hint to understand the skill. #### tags ```python tags: list[str] ``` Set of tag-words describing classes of capabilities for this specific skill. Examples: "cooking", "customer support", "billing". #### examples ```python examples: NotRequired[list[str]] ``` The set of example scenarios that the skill can perform. Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") #### input_modes ```python input_modes: list[str] ``` Supported mime types for input data. #### output_modes ```python output_modes: list[str] ``` Supported mime types for output data. ### Storage Bases: `ABC` A storage to retrieve and save tasks. The storage is used to update the status of a task and to save the result of a task. Source code in `fasta2a/fasta2a/storage.py` ```python class Storage(ABC): """A storage to retrieve and save tasks. The storage is used to update the status of a task and to save the result of a task. """ @abstractmethod async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from storage. If the task is not found, return None. """ @abstractmethod async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: """Submit a task to storage.""" @abstractmethod async def update_task( self, task_id: str, state: TaskState, message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task: """Update the state of a task.""" ``` #### load_task ```python load_task( task_id: str, history_length: int | None = None ) -> Task | None ``` Load a task from storage. If the task is not found, return None. Source code in `fasta2a/fasta2a/storage.py` ```python @abstractmethod async def load_task(self, task_id: str, history_length: int | None = None) -> Task | None: """Load a task from storage. If the task is not found, return None. """ ``` #### submit_task ```python submit_task( task_id: str, session_id: str, message: Message ) -> Task ``` Submit a task to storage. 
Source code in `fasta2a/fasta2a/storage.py` ```python @abstractmethod async def submit_task(self, task_id: str, session_id: str, message: Message) -> Task: """Submit a task to storage.""" ``` #### update_task ```python update_task( task_id: str, state: TaskState, message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task ``` Update the state of a task. Source code in `fasta2a/fasta2a/storage.py` ```python @abstractmethod async def update_task( self, task_id: str, state: TaskState, message: Message | None = None, artifacts: list[Artifact] | None = None, ) -> Task: """Update the state of a task.""" ``` ### Worker Bases: `ABC` A worker is responsible for executing tasks. Source code in `fasta2a/fasta2a/worker.py` ```python @dataclass class Worker(ABC): """A worker is responsible for executing tasks.""" broker: Broker storage: Storage @asynccontextmanager async def run(self) -> AsyncIterator[None]: """Run the worker. It connects to the broker, and it makes itself available to receive commands. """ async with anyio.create_task_group() as tg: tg.start_soon(self._loop) yield tg.cancel_scope.cancel() async def _loop(self) -> None: async for task_operation in self.broker.receive_task_operations(): await self._handle_task_operation(task_operation) async def _handle_task_operation(self, task_operation: TaskOperation) -> None: try: with use_span(task_operation['_current_span']): with tracer.start_as_current_span( f'{task_operation["operation"]} task', attributes={'logfire.tags': ['fasta2a']} ): if task_operation['operation'] == 'run': await self.run_task(task_operation['params']) elif task_operation['operation'] == 'cancel': await self.cancel_task(task_operation['params']) else: assert_never(task_operation) except Exception: await self.storage.update_task(task_operation['params']['id'], state='failed') @abstractmethod async def run_task(self, params: TaskSendParams) -> None: ... @abstractmethod async def cancel_task(self, params: TaskIdParams) -> None: ... @abstractmethod def build_message_history(self, task_history: list[Message]) -> list[Any]: ... @abstractmethod def build_artifacts(self, result: Any) -> list[Artifact]: ... ``` #### run ```python run() -> AsyncIterator[None] ``` Run the worker. It connects to the broker, and it makes itself available to receive commands. Source code in `fasta2a/fasta2a/worker.py` ```python @asynccontextmanager async def run(self) -> AsyncIterator[None]: """Run the worker. It connects to the broker, and it makes itself available to receive commands. """ async with anyio.create_task_group() as tg: tg.start_soon(self._loop) yield tg.cancel_scope.cancel() ``` This module contains the schema for the agent card. ### AgentCard Bases: `TypedDict` The card that describes an agent. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class AgentCard(TypedDict): """The card that describes an agent.""" name: str """Human readable name of the agent e.g. "Recipe Agent".""" description: NotRequired[str] """A human-readable description of the agent. Used to assist users and other agents in understanding what the agent can do. (e.g. "Agent that helps users with recipes and cooking.") """ # TODO(Marcelo): The spec makes url required. url: NotRequired[str] """A URL to the address the agent is hosted at.""" provider: NotRequired[Provider] """The service provider of the agent.""" # TODO(Marcelo): The spec makes version required. 
version: NotRequired[str] """The version of the agent - format is up to the provider. (e.g. "1.0.0")""" documentation_url: NotRequired[str] """A URL to documentation for the agent.""" capabilities: Capabilities """The capabilities of the agent.""" authentication: Authentication """The authentication schemes supported by the agent. Intended to match OpenAPI authentication structure. """ default_input_modes: list[str] """Supported mime types for input data.""" default_output_modes: list[str] """Supported mime types for output data.""" skills: list[Skill] ``` #### name ```python name: str ``` Human readable name of the agent e.g. "Recipe Agent". #### description ```python description: NotRequired[str] ``` A human-readable description of the agent. Used to assist users and other agents in understanding what the agent can do. (e.g. "Agent that helps users with recipes and cooking.") #### url ```python url: NotRequired[str] ``` A URL to the address the agent is hosted at. #### provider ```python provider: NotRequired[Provider] ``` The service provider of the agent. #### version ```python version: NotRequired[str] ``` The version of the agent - format is up to the provider. (e.g. "1.0.0") #### documentation_url ```python documentation_url: NotRequired[str] ``` A URL to documentation for the agent. #### capabilities ```python capabilities: Capabilities ``` The capabilities of the agent. #### authentication ```python authentication: Authentication ``` The authentication schemes supported by the agent. Intended to match OpenAPI authentication structure. #### default_input_modes ```python default_input_modes: list[str] ``` Supported mime types for input data. #### default_output_modes ```python default_output_modes: list[str] ``` Supported mime types for output data. ### Provider Bases: `TypedDict` The service provider of the agent. Source code in `fasta2a/fasta2a/schema.py` ```python class Provider(TypedDict): """The service provider of the agent.""" organization: str url: str ``` ### Capabilities Bases: `TypedDict` The capabilities of the agent. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class Capabilities(TypedDict): """The capabilities of the agent.""" streaming: NotRequired[bool] """Whether the agent supports streaming.""" push_notifications: NotRequired[bool] """Whether the agent can notify updates to client.""" state_transition_history: NotRequired[bool] """Whether the agent exposes status change history for tasks.""" ``` #### streaming ```python streaming: NotRequired[bool] ``` Whether the agent supports streaming. #### push_notifications ```python push_notifications: NotRequired[bool] ``` Whether the agent can notify updates to client. #### state_transition_history ```python state_transition_history: NotRequired[bool] ``` Whether the agent exposes status change history for tasks. ### Authentication Bases: `TypedDict` The authentication schemes supported by the agent. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class Authentication(TypedDict): """The authentication schemes supported by the agent.""" schemes: list[str] """The authentication schemes supported by the agent. (e.g. "Basic", "Bearer")""" credentials: NotRequired[str] """The credentials a client should use for private cards.""" ``` #### schemes ```python schemes: list[str] ``` The authentication schemes supported by the agent. (e.g. 
"Basic", "Bearer") #### credentials ```python credentials: NotRequired[str] ``` The credentials a client should use for private cards. ### Skill Bases: `TypedDict` Skills are a unit of capability that an agent can perform. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class Skill(TypedDict): """Skills are a unit of capability that an agent can perform.""" id: str """A unique identifier for the skill.""" name: str """Human readable name of the skill.""" description: str """A human-readable description of the skill. It will be used by the client or a human as a hint to understand the skill. """ tags: list[str] """Set of tag-words describing classes of capabilities for this specific skill. Examples: "cooking", "customer support", "billing". """ examples: NotRequired[list[str]] """The set of example scenarios that the skill can perform. Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") """ input_modes: list[str] """Supported mime types for input data.""" output_modes: list[str] """Supported mime types for output data.""" ``` #### id ```python id: str ``` A unique identifier for the skill. #### name ```python name: str ``` Human readable name of the skill. #### description ```python description: str ``` A human-readable description of the skill. It will be used by the client or a human as a hint to understand the skill. #### tags ```python tags: list[str] ``` Set of tag-words describing classes of capabilities for this specific skill. Examples: "cooking", "customer support", "billing". #### examples ```python examples: NotRequired[list[str]] ``` The set of example scenarios that the skill can perform. Will be used by the client as a hint to understand how the skill can be used. (e.g. "I need a recipe for bread") #### input_modes ```python input_modes: list[str] ``` Supported mime types for input data. #### output_modes ```python output_modes: list[str] ``` Supported mime types for output data. ### Artifact Bases: `TypedDict` Agents generate Artifacts as an end result of a Task. Artifacts are immutable, can be named, and can have multiple parts. A streaming response can append parts to existing Artifacts. A single Task can generate many Artifacts. For example, "create a webpage" could create separate HTML and image Artifacts. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class Artifact(TypedDict): """Agents generate Artifacts as an end result of a Task. Artifacts are immutable, can be named, and can have multiple parts. A streaming response can append parts to existing Artifacts. A single Task can generate many Artifacts. For example, "create a webpage" could create separate HTML and image Artifacts. """ name: NotRequired[str] """The name of the artifact.""" description: NotRequired[str] """A description of the artifact.""" parts: list[Part] """The parts that make up the artifact.""" metadata: NotRequired[dict[str, Any]] """Metadata about the artifact.""" index: int """The index of the artifact.""" append: NotRequired[bool] """Whether to append this artifact to an existing one.""" last_chunk: NotRequired[bool] """Whether this is the last chunk of the artifact.""" ``` #### name ```python name: NotRequired[str] ``` The name of the artifact. #### description ```python description: NotRequired[str] ``` A description of the artifact. #### parts ```python parts: list[Part] ``` The parts that make up the artifact. 
#### metadata ```python metadata: NotRequired[dict[str, Any]] ``` Metadata about the artifact. #### index ```python index: int ``` The index of the artifact. #### append ```python append: NotRequired[bool] ``` Whether to append this artifact to an existing one. #### last_chunk ```python last_chunk: NotRequired[bool] ``` Whether this is the last chunk of the artifact. ### PushNotificationConfig Bases: `TypedDict` Configuration for push notifications. A2A supports a secure notification mechanism whereby an agent can notify a client of an update outside of a connected session via a PushNotificationService. Within and across enterprises, it is critical that the agent verifies the identity of the notification service, authenticates itself with the service, and presents an identifier that ties the notification to the executing Task. The target server of the PushNotificationService should be considered a separate service, and is not guaranteed (or even expected) to be the client directly. This PushNotificationService is responsible for authenticating and authorizing the agent and for proxying the verified notification to the appropriate endpoint (which could be anything from a pub/sub queue, to an email inbox or other service, etc). For contrived scenarios with isolated client-agent pairs (e.g. local service mesh in a contained VPC, etc.) or isolated environments without enterprise security concerns, the client may choose to simply open a port and act as its own PushNotificationService. Any enterprise implementation will likely have a centralized service that authenticates the remote agents with trusted notification credentials and can handle online/offline scenarios. (This should be thought of similarly to a mobile Push Notification Service). Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class PushNotificationConfig(TypedDict): """Configuration for push notifications. A2A supports a secure notification mechanism whereby an agent can notify a client of an update outside of a connected session via a PushNotificationService. Within and across enterprises, it is critical that the agent verifies the identity of the notification service, authenticates itself with the service, and presents an identifier that ties the notification to the executing Task. The target server of the PushNotificationService should be considered a separate service, and is not guaranteed (or even expected) to be the client directly. This PushNotificationService is responsible for authenticating and authorizing the agent and for proxying the verified notification to the appropriate endpoint (which could be anything from a pub/sub queue, to an email inbox or other service, etc). For contrived scenarios with isolated client-agent pairs (e.g. local service mesh in a contained VPC, etc.) or isolated environments without enterprise security concerns, the client may choose to simply open a port and act as its own PushNotificationService. Any enterprise implementation will likely have a centralized service that authenticates the remote agents with trusted notification credentials and can handle online/offline scenarios. (This should be thought of similarly to a mobile Push Notification Service). 
""" url: str """The URL to send push notifications to.""" token: NotRequired[str] """Token unique to this task/session.""" authentication: NotRequired[Authentication] """Authentication details for push notifications.""" ``` #### url ```python url: str ``` The URL to send push notifications to. #### token ```python token: NotRequired[str] ``` Token unique to this task/session. #### authentication ```python authentication: NotRequired[Authentication] ``` Authentication details for push notifications. ### TaskPushNotificationConfig Bases: `TypedDict` Configuration for task push notifications. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskPushNotificationConfig(TypedDict): """Configuration for task push notifications.""" id: str """The task id.""" push_notification_config: PushNotificationConfig """The push notification configuration.""" ``` #### id ```python id: str ``` The task id. #### push_notification_config ```python push_notification_config: PushNotificationConfig ``` The push notification configuration. ### Message Bases: `TypedDict` A Message contains any content that is not an Artifact. This can include things like agent thoughts, user context, instructions, errors, status, or metadata. All content from a client comes in the form of a Message. Agents send Messages to communicate status or to provide instructions (whereas generated results are sent as Artifacts). A Message can have multiple parts to denote different pieces of content. For example, a user request could include a textual description from a user and then multiple files used as context from the client. Source code in `fasta2a/fasta2a/schema.py` ```python class Message(TypedDict): """A Message contains any content that is not an Artifact. This can include things like agent thoughts, user context, instructions, errors, status, or metadata. All content from a client comes in the form of a Message. Agents send Messages to communicate status or to provide instructions (whereas generated results are sent as Artifacts). A Message can have multiple parts to denote different pieces of content. For example, a user request could include a textual description from a user and then multiple files used as context from the client. """ role: Literal['user', 'agent'] """The role of the message.""" parts: list[Part] """The parts of the message.""" metadata: NotRequired[dict[str, Any]] """Metadata about the message.""" ``` #### role ```python role: Literal['user', 'agent'] ``` The role of the message. #### parts ```python parts: list[Part] ``` The parts of the message. #### metadata ```python metadata: NotRequired[dict[str, Any]] ``` Metadata about the message. ### TextPart Bases: `_BasePart` A part that contains text. Source code in `fasta2a/fasta2a/schema.py` ```python class TextPart(_BasePart): """A part that contains text.""" type: Literal['text'] """The type of the part.""" text: str """The text of the part.""" ``` #### type ```python type: Literal['text'] ``` The type of the part. #### text ```python text: str ``` The text of the part. ### FilePart Bases: `_BasePart` A part that contains a file. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class FilePart(_BasePart): """A part that contains a file.""" type: Literal['file'] """The type of the part.""" file: File """The file of the part.""" ``` #### type ```python type: Literal['file'] ``` The type of the part. 
#### file ```python file: File ``` The file of the part. ### File ```python File: TypeAlias = Union[_BinaryFile, _URLFile] ``` A file is a binary file or a URL file. ### DataPart Bases: `_BasePart` A part that contains data. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class DataPart(_BasePart): """A part that contains data.""" type: Literal['data'] """The type of the part.""" data: dict[str, Any] """The data of the part.""" ``` #### type ```python type: Literal['data'] ``` The type of the part. #### data ```python data: dict[str, Any] ``` The data of the part. ### Part ```python Part = Annotated[ Union[TextPart, FilePart, DataPart], Field(discriminator="type"), ] ``` A fully formed piece of content exchanged between a client and a remote agent as part of a Message or an Artifact. Each Part has its own content type and metadata. ### TaskState ```python TaskState: TypeAlias = Literal[ "submitted", "working", "input-required", "completed", "canceled", "failed", "unknown", ] ``` The possible states of a task. ### TaskStatus Bases: `TypedDict` Status and accompanying message for a task. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskStatus(TypedDict): """Status and accompanying message for a task.""" state: TaskState """The current state of the task.""" message: NotRequired[Message] """Additional status updates for client.""" timestamp: NotRequired[str] """ISO datetime value of when the status was updated.""" ``` #### state ```python state: TaskState ``` The current state of the task. #### message ```python message: NotRequired[Message] ``` Additional status updates for client. #### timestamp ```python timestamp: NotRequired[str] ``` ISO datetime value of when the status was updated. ### Task Bases: `TypedDict` A Task is a stateful entity that allows Clients and Remote Agents to achieve a specific outcome. Clients and Remote Agents exchange Messages within a Task. Remote Agents generate results as Artifacts. A Task is always created by a Client and the status is always determined by the Remote Agent. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class Task(TypedDict): """A Task is a stateful entity that allows Clients and Remote Agents to achieve a specific outcome. Clients and Remote Agents exchange Messages within a Task. Remote Agents generate results as Artifacts. A Task is always created by a Client and the status is always determined by the Remote Agent. """ id: str """Unique identifier for the task.""" session_id: NotRequired[str] """Client-generated id for the session holding the task.""" status: TaskStatus """Current status of the task.""" history: NotRequired[list[Message]] """Optional history of messages.""" artifacts: NotRequired[list[Artifact]] """Collection of artifacts created by the agent.""" metadata: NotRequired[dict[str, Any]] """Extension metadata.""" ``` #### id ```python id: str ``` Unique identifier for the task. #### session_id ```python session_id: NotRequired[str] ``` Client-generated id for the session holding the task. #### status ```python status: TaskStatus ``` Current status of the task. #### history ```python history: NotRequired[list[Message]] ``` Optional history of messages. #### artifacts ```python artifacts: NotRequired[list[Artifact]] ``` Collection of artifacts created by the agent. #### metadata ```python metadata: NotRequired[dict[str, Any]] ``` Extension metadata. 
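To make the shape of these TypedDicts concrete, here's a minimal sketch of building a `Task` by hand. It assumes the types are importable from `fasta2a.schema` (matching the source path shown above); all ids, timestamps, and message text are illustrative:

```python
from fasta2a.schema import Message, Task, TaskStatus, TextPart

# A user message made of a single text part.
message = Message(
    role='user',
    parts=[TextPart(type='text', text='I need a recipe for bread')],
)

# A task as a server might represent it while the agent is working on it.
task = Task(
    id='illustrative-task-id',
    session_id='illustrative-session-id',
    status=TaskStatus(state='working', timestamp='2025-01-01T00:00:00Z'),
    history=[message],
)
print(task['status']['state'])
#> working
```

Note that the Python field names stay snake_case; the `to_camel` alias generator only applies when the schema is serialized, e.g. `session_id` becomes `sessionId` on the wire.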
### TaskStatusUpdateEvent Bases: `TypedDict` Sent by server during sendSubscribe or subscribe requests. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskStatusUpdateEvent(TypedDict): """Sent by server during sendSubscribe or subscribe requests.""" id: str """The id of the task.""" status: TaskStatus """The status of the task.""" final: bool """Indicates the end of the event stream.""" metadata: NotRequired[dict[str, Any]] """Extension metadata.""" ``` #### id ```python id: str ``` The id of the task. #### status ```python status: TaskStatus ``` The status of the task. #### final ```python final: bool ``` Indicates the end of the event stream. #### metadata ```python metadata: NotRequired[dict[str, Any]] ``` Extension metadata. ### TaskArtifactUpdateEvent Bases: `TypedDict` Sent by server during sendSubscribe or subscribe requests. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskArtifactUpdateEvent(TypedDict): """Sent by server during sendSubscribe or subscribe requests.""" id: str """The id of the task.""" artifact: Artifact """The artifact that was updated.""" metadata: NotRequired[dict[str, Any]] """Extension metadata.""" ``` #### id ```python id: str ``` The id of the task. #### artifact ```python artifact: Artifact ``` The artifact that was updated. #### metadata ```python metadata: NotRequired[dict[str, Any]] ``` Extension metadata. ### TaskIdParams Bases: `TypedDict` Parameters for a task id. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskIdParams(TypedDict): """Parameters for a task id.""" id: str metadata: NotRequired[dict[str, Any]] ``` ### TaskQueryParams Bases: `TaskIdParams` Query parameters for a task. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskQueryParams(TaskIdParams): """Query parameters for a task.""" history_length: NotRequired[int] """Number of recent messages to be retrieved.""" ``` #### history_length ```python history_length: NotRequired[int] ``` Number of recent messages to be retrieved. ### TaskSendParams Bases: `TypedDict` Sent by the client to the agent to create, continue, or restart a task. Source code in `fasta2a/fasta2a/schema.py` ```python @pydantic.with_config({'alias_generator': to_camel}) class TaskSendParams(TypedDict): """Sent by the client to the agent to create, continue, or restart a task.""" id: str """The id of the task.""" session_id: NotRequired[str] """The server creates a new sessionId for new tasks if not set.""" message: Message """The message to send to the agent.""" history_length: NotRequired[int] """Number of recent messages to be retrieved.""" push_notification: NotRequired[PushNotificationConfig] """Where the server should send notifications when disconnected.""" metadata: NotRequired[dict[str, Any]] """Extension metadata.""" ``` #### id ```python id: str ``` The id of the task. #### session_id ```python session_id: NotRequired[str] ``` The server creates a new sessionId for new tasks if not set. #### message ```python message: Message ``` The message to send to the agent. #### history_length ```python history_length: NotRequired[int] ``` Number of recent messages to be retrieved. #### push_notification ```python push_notification: NotRequired[PushNotificationConfig] ``` Where the server should send notifications when disconnected. 
#### metadata ```python metadata: NotRequired[dict[str, Any]] ``` Extension metadata. ### JSONRPCMessage Bases: `TypedDict` A JSON RPC message. Source code in `fasta2a/fasta2a/schema.py` ```python class JSONRPCMessage(TypedDict): """A JSON RPC message.""" jsonrpc: Literal['2.0'] """The JSON RPC version.""" id: int | str | None """The request id.""" ``` #### jsonrpc ```python jsonrpc: Literal['2.0'] ``` The JSON RPC version. #### id ```python id: int | str | None ``` The request id. ### JSONRPCRequest Bases: `JSONRPCMessage`, `Generic[Method, Params]` A JSON RPC request. Source code in `fasta2a/fasta2a/schema.py` ```python class JSONRPCRequest(JSONRPCMessage, Generic[Method, Params]): """A JSON RPC request.""" method: Method """The method to call.""" params: Params """The parameters to pass to the method.""" ``` #### method ```python method: Method ``` The method to call. #### params ```python params: Params ``` The parameters to pass to the method. ### JSONRPCError Bases: `TypedDict`, `Generic[CodeT, MessageT]` A JSON RPC error. Source code in `fasta2a/fasta2a/schema.py` ```python class JSONRPCError(TypedDict, Generic[CodeT, MessageT]): """A JSON RPC error.""" code: CodeT message: MessageT data: NotRequired[Any] ``` ### JSONRPCResponse Bases: `JSONRPCMessage`, `Generic[ResultT, ErrorT]` A JSON RPC response. Source code in `fasta2a/fasta2a/schema.py` ```python class JSONRPCResponse(JSONRPCMessage, Generic[ResultT, ErrorT]): """A JSON RPC response.""" result: NotRequired[ResultT] error: NotRequired[ErrorT] ``` ### JSONParseError ```python JSONParseError = JSONRPCError[ Literal[-32700], Literal["Invalid JSON payload"] ] ``` A JSON RPC error for a parse error. ### InvalidRequestError ```python InvalidRequestError = JSONRPCError[ Literal[-32600], Literal["Request payload validation error"], ] ``` A JSON RPC error for an invalid request. ### MethodNotFoundError ```python MethodNotFoundError = JSONRPCError[ Literal[-32601], Literal["Method not found"] ] ``` A JSON RPC error for a method not found. ### InvalidParamsError ```python InvalidParamsError = JSONRPCError[ Literal[-32602], Literal["Invalid parameters"] ] ``` A JSON RPC error for invalid parameters. ### InternalError ```python InternalError = JSONRPCError[ Literal[-32603], Literal["Internal error"] ] ``` A JSON RPC error for an internal error. ### TaskNotFoundError ```python TaskNotFoundError = JSONRPCError[ Literal[-32001], Literal["Task not found"] ] ``` A JSON RPC error for a task not found. ### TaskNotCancelableError ```python TaskNotCancelableError = JSONRPCError[ Literal[-32002], Literal["Task not cancelable"] ] ``` A JSON RPC error for a task not cancelable. ### PushNotificationNotSupportedError ```python PushNotificationNotSupportedError = JSONRPCError[ Literal[-32003], Literal["Push notification not supported"], ] ``` A JSON RPC error for a push notification not supported. ### UnsupportedOperationError ```python UnsupportedOperationError = JSONRPCError[ Literal[-32004], Literal["This operation is not supported"], ] ``` A JSON RPC error for an unsupported operation. ### ContentTypeNotSupportedError ```python ContentTypeNotSupportedError = JSONRPCError[ Literal[-32005], Literal["Incompatible content types"] ] ``` A JSON RPC error for incompatible content types. ### SendTaskRequest ```python SendTaskRequest = JSONRPCRequest[ Literal["tasks/send"], TaskSendParams ] ``` A JSON RPC request to send a task. 
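Putting the JSON RPC pieces together, a `SendTaskRequest` is just a JSON RPC 2.0 envelope whose `params` is a `TaskSendParams`; this mirrors how `A2AClient.send_task` builds its payload below. A minimal sketch, assuming the types are importable from `fasta2a.schema`; the message text is illustrative:

```python
import uuid

from fasta2a.schema import Message, SendTaskRequest, TaskSendParams, TextPart

params = TaskSendParams(
    id=str(uuid.uuid4()),  # client-generated task id
    message=Message(role='user', parts=[TextPart(type='text', text='I need a recipe for bread')]),
)

# The request `id` is the JSON RPC correlation id, distinct from the task id.
request = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=params)
print(request['method'])
#> tasks/send
```

The error aliases work the same way in the other direction: each one pins a fixed JSON RPC error `code` and `message` pair, so a response carries either a `result` or one of these typed errors.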
### SendTaskResponse ```python SendTaskResponse = JSONRPCResponse[ Task, JSONRPCError[Any, Any] ] ``` A JSON RPC response to send a task. ### SendTaskStreamingRequest ```python SendTaskStreamingRequest = JSONRPCRequest[ Literal["tasks/sendSubscribe"], TaskSendParams ] ``` A JSON RPC request to send a task and receive updates. ### SendTaskStreamingResponse ```python SendTaskStreamingResponse = JSONRPCResponse[ Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent], InternalError, ] ``` A JSON RPC response to send a task and receive updates. ### GetTaskRequest ```python GetTaskRequest = JSONRPCRequest[ Literal["tasks/get"], TaskQueryParams ] ``` A JSON RPC request to get a task. ### GetTaskResponse ```python GetTaskResponse = JSONRPCResponse[Task, TaskNotFoundError] ``` A JSON RPC response to get a task. ### CancelTaskRequest ```python CancelTaskRequest = JSONRPCRequest[ Literal["tasks/cancel"], TaskIdParams ] ``` A JSON RPC request to cancel a task. ### CancelTaskResponse ```python CancelTaskResponse = JSONRPCResponse[ Task, Union[TaskNotCancelableError, TaskNotFoundError] ] ``` A JSON RPC response to cancel a task. ### SetTaskPushNotificationRequest ```python SetTaskPushNotificationRequest = JSONRPCRequest[ Literal["tasks/pushNotification/set"], TaskPushNotificationConfig, ] ``` A JSON RPC request to set a task push notification. ### SetTaskPushNotificationResponse ```python SetTaskPushNotificationResponse = JSONRPCResponse[ TaskPushNotificationConfig, PushNotificationNotSupportedError, ] ``` A JSON RPC response to set a task push notification. ### GetTaskPushNotificationRequest ```python GetTaskPushNotificationRequest = JSONRPCRequest[ Literal["tasks/pushNotification/get"], TaskIdParams ] ``` A JSON RPC request to get a task push notification. ### GetTaskPushNotificationResponse ```python GetTaskPushNotificationResponse = JSONRPCResponse[ TaskPushNotificationConfig, PushNotificationNotSupportedError, ] ``` A JSON RPC response to get a task push notification. ### ResubscribeTaskRequest ```python ResubscribeTaskRequest = JSONRPCRequest[ Literal["tasks/resubscribe"], TaskIdParams ] ``` A JSON RPC request to resubscribe to a task. ### A2ARequest ```python A2ARequest = Annotated[ Union[ SendTaskRequest, GetTaskRequest, CancelTaskRequest, SetTaskPushNotificationRequest, GetTaskPushNotificationRequest, ResubscribeTaskRequest, ], Discriminator("method"), ] ``` A JSON RPC request to the A2A server. ### A2AResponse ```python A2AResponse: TypeAlias = Union[ SendTaskResponse, GetTaskResponse, CancelTaskResponse, SetTaskPushNotificationResponse, GetTaskPushNotificationResponse, ] ``` A JSON RPC response from the A2A server. ### A2AClient A client for the A2A protocol. 
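For example, here's a rough sketch of a client round trip. It assumes an A2A server is already running at the given URL and that `A2AClient` and the schema types are importable from `fasta2a.client` and `fasta2a.schema` respectively; error handling is elided:

```python
from fasta2a.client import A2AClient
from fasta2a.schema import Message, TextPart


async def main():
    client = A2AClient(base_url='http://localhost:8000')  # illustrative URL
    message = Message(role='user', parts=[TextPart(type='text', text='I need a recipe for bread')])

    send_response = await client.send_task(message=message)
    task = send_response['result']  # on failure an `error` key is set instead

    # Poll the task using the id returned in the result.
    get_response = await client.get_task(task['id'])
    print(get_response['result']['status']['state'])
```

*(As with the other async examples, you'll need to add `asyncio.run(main())` to run `main`.)*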
Source code in `fasta2a/fasta2a/client.py` ```python class A2AClient: """A client for the A2A protocol.""" def __init__(self, base_url: str = 'http://localhost:8000', http_client: httpx.AsyncClient | None = None) -> None: if http_client is None: self.http_client = httpx.AsyncClient(base_url=base_url) else: self.http_client = http_client self.http_client.base_url = base_url async def send_task( self, message: Message, history_length: int | None = None, push_notification: PushNotificationConfig | None = None, metadata: dict[str, Any] | None = None, ) -> SendTaskResponse: task = TaskSendParams(message=message, id=str(uuid.uuid4())) if history_length is not None: task['history_length'] = history_length if push_notification is not None: task['push_notification'] = push_notification if metadata is not None: task['metadata'] = metadata payload = SendTaskRequest(jsonrpc='2.0', id=None, method='tasks/send', params=task) content = a2a_request_ta.dump_json(payload, by_alias=True) response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) return send_task_response_ta.validate_json(response.content) async def get_task(self, task_id: str) -> GetTaskResponse: payload = GetTaskRequest(jsonrpc='2.0', id=None, method='tasks/get', params={'id': task_id}) content = a2a_request_ta.dump_json(payload, by_alias=True) response = await self.http_client.post('/', content=content, headers={'Content-Type': 'application/json'}) self._raise_for_status(response) return get_task_response_ta.validate_json(response.content) def _raise_for_status(self, response: httpx.Response) -> None: if response.status_code >= 400: raise UnexpectedResponseError(response.status_code, response.text) ``` ### UnexpectedResponseError Bases: `Exception` An error raised when an unexpected response is received from the server. Source code in `fasta2a/fasta2a/client.py` ```python class UnexpectedResponseError(Exception): """An error raised when an unexpected response is received from the server.""" def __init__(self, status_code: int, content: str) -> None: self.status_code = status_code self.content = content ``` # `pydantic_ai.format_as_xml` ### format_as_xml ```python format_as_xml(prompt: str) -> str ``` `format_as_xml` has moved, import it via `from pydantic_ai import format_as_xml` instead. Source code in `pydantic_ai_slim/pydantic_ai/format_as_xml.py` ```python @deprecated('`format_as_xml` has moved, import it via `from pydantic_ai import format_as_xml`') def format_as_xml(prompt: str) -> str: """`format_as_xml` has moved, import it via `from pydantic_ai import format_as_xml` instead.""" return _format_as_xml(prompt) ``` # `pydantic_ai.format_prompt` ### format_as_xml ```python format_as_xml( obj: Any, root_tag: str = "examples", item_tag: str = "example", include_root_tag: bool = True, none_str: str = "null", indent: str | None = " ", ) -> str ``` Format a Python object as XML. This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML, rather than JSON etc. Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`, `Iterable`, `dataclass`, and `BaseModel`. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `obj` | `Any` | Python Object to serialize to XML. | *required* | | `root_tag` | `str` | Outer tag to wrap the XML in, use None to omit the outer tag. | `'examples'` | | `item_tag` | `str` | Tag to use for each item in an iterable (e.g. 
list), this is overridden by the class name for dataclasses and Pydantic models. | `'example'` | | `include_root_tag` | `bool` | Whether to include the root tag in the output (The root tag is always included if it includes a body - e.g. when the input is a simple value). | `True` | | `none_str` | `str` | String to use for None values. | `'null'` | | `indent` | `str | None` | Indentation string to use for pretty printing. | `' '` | Returns: | Type | Description | | --- | --- | | `str` | XML representation of the object. | Example: format_as_xml_example.py ```python from pydantic_ai import format_as_xml print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user')) ''' <user> <name>John</name> <height>6</height> <weight>200</weight> </user> ''' ``` Source code in `pydantic_ai_slim/pydantic_ai/format_prompt.py` ````python def format_as_xml( obj: Any, root_tag: str = 'examples', item_tag: str = 'example', include_root_tag: bool = True, none_str: str = 'null', indent: str | None = ' ', ) -> str: """Format a Python object as XML. This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML, rather than JSON etc. Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`, `Iterable`, `dataclass`, and `BaseModel`. Args: obj: Python Object to serialize to XML. root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag. item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name for dataclasses and Pydantic models. include_root_tag: Whether to include the root tag in the output (The root tag is always included if it includes a body - e.g. when the input is a simple value). none_str: String to use for `None` values. indent: Indentation string to use for pretty printing. Returns: XML representation of the object. Example: ```python {title="format_as_xml_example.py" lint="skip"} from pydantic_ai import format_as_xml print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user')) ''' <user> <name>John</name> <height>6</height> <weight>200</weight> </user> ''' ``` """ el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag) if not include_root_tag and el.text is None: join = '' if indent is None else '\n' return join.join(_rootless_xml_elements(el, indent)) else: if indent is not None: ElementTree.indent(el, space=indent) return ElementTree.tostring(el, encoding='unicode') ```` # `pydantic_ai.mcp` ### MCPServer Bases: `ABC` Base class for attaching agents to MCP servers. See for more information. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ```python class MCPServer(ABC): """Base class for attaching agents to MCP servers. See for more information.
""" # these fields should be re-defined by dataclass subclasses so they appear as fields { tool_prefix: str | None = None log_level: mcp_types.LoggingLevel | None = None log_handler: LoggingFnT | None = None timeout: float = 5 process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True # } end of "abstract fields" _running_count: int = 0 _client: ClientSession _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] _exit_stack: AsyncExitStack sampling_model: models.Model | None = None @abstractmethod @asynccontextmanager async def client_streams( self, ) -> AsyncIterator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], ] ]: """Create the streams for the MCP server.""" raise NotImplementedError('MCP Server subclasses must implement this method.') yield def get_prefixed_tool_name(self, tool_name: str) -> str: """Get the tool name with prefix if `tool_prefix` is set.""" return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name def get_unprefixed_tool_name(self, tool_name: str) -> str: """Get original tool name without prefix for calling tools.""" return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name @property def is_running(self) -> bool: """Check if the MCP server is running.""" return bool(self._running_count) async def list_tools(self) -> list[tools.ToolDefinition]: """Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ mcp_tools = await self._client.list_tools() return [ tools.ToolDefinition( name=self.get_prefixed_tool_name(tool.name), description=tool.description or '', parameters_json_schema=tool.inputSchema, ) for tool in mcp_tools.tools ] async def call_tool( self, tool_name: str, arguments: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> ToolResult: """Call a tool on the server. Args: tool_name: The name of the tool to call. arguments: The arguments to pass to the tool. metadata: Request-level metadata (optional) Returns: The result of the tool call. Raises: ModelRetry: If the tool call fails. """ try: # meta param is not provided by session yet, so build and can send_request directly. 
result = await self._client.send_request( mcp_types.ClientRequest( mcp_types.CallToolRequest( method='tools/call', params=mcp_types.CallToolRequestParams( name=self.get_unprefixed_tool_name(tool_name), arguments=arguments, _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, ), ) ), mcp_types.CallToolResult, ) except McpError as e: raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] if result.isError: text = '\n'.join(str(part) for part in content) raise exceptions.ModelRetry(text) else: return content[0] if len(content) == 1 else content async def __aenter__(self) -> Self: if self._running_count == 0: self._exit_stack = AsyncExitStack() self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams()) client = ClientSession( read_stream=self._read_stream, write_stream=self._write_stream, sampling_callback=self._sampling_callback if self.allow_sampling else None, logging_callback=self.log_handler, ) self._client = await self._exit_stack.enter_async_context(client) with anyio.fail_after(self.timeout): await self._client.initialize() if log_level := self.log_level: await self._client.set_logging_level(log_level) self._running_count += 1 return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> bool | None: self._running_count -= 1 if self._running_count <= 0: await self._exit_stack.aclose() async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: """MCP sampling callback.""" if self.sampling_model is None: raise ValueError('Sampling model is not set') # pragma: no cover pai_messages = _mcp.map_from_mcp_params(params) model_settings = models.ModelSettings() if max_tokens := params.maxTokens: # pragma: no branch model_settings['max_tokens'] = max_tokens if temperature := params.temperature: # pragma: no branch model_settings['temperature'] = temperature if stop_sequences := params.stopSequences: # pragma: no branch model_settings['stop_sequences'] = stop_sequences model_response = await self.sampling_model.request( pai_messages, model_settings, models.ModelRequestParameters(), ) return mcp_types.CreateMessageResult( role='assistant', content=_mcp.map_from_model_response(model_response), model=self.sampling_model.model_name, ) def _map_tool_result_part( self, part: mcp_types.Content ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values if isinstance(part, mcp_types.TextContent): text = part.text if text.startswith(('[', '{')): try: return pydantic_core.from_json(text) except ValueError: pass return text elif isinstance(part, mcp_types.ImageContent): return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) elif isinstance(part, mcp_types.AudioContent): # NOTE: The FastMCP server doesn't support audio content. # See for more details. 
return messages.BinaryContent( data=base64.b64decode(part.data), media_type=part.mimeType ) # pragma: no cover elif isinstance(part, mcp_types.EmbeddedResource): resource = part.resource if isinstance(resource, mcp_types.TextResourceContents): return resource.text elif isinstance(resource, mcp_types.BlobResourceContents): return messages.BinaryContent( data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream', ) else: assert_never(resource) else: assert_never(part) ``` #### client_streams ```python client_streams() -> AsyncIterator[ tuple[ MemoryObjectReceiveStream[ SessionMessage | Exception ], MemoryObjectSendStream[SessionMessage], ] ] ``` Create the streams for the MCP server. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ```python @abstractmethod @asynccontextmanager async def client_streams( self, ) -> AsyncIterator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], ] ]: """Create the streams for the MCP server.""" raise NotImplementedError('MCP Server subclasses must implement this method.') yield ``` #### get_prefixed_tool_name ```python get_prefixed_tool_name(tool_name: str) -> str ``` Get the tool name with prefix if `tool_prefix` is set. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ```python def get_prefixed_tool_name(self, tool_name: str) -> str: """Get the tool name with prefix if `tool_prefix` is set.""" return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name ``` #### get_unprefixed_tool_name ```python get_unprefixed_tool_name(tool_name: str) -> str ``` Get original tool name without prefix for calling tools. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ```python def get_unprefixed_tool_name(self, tool_name: str) -> str: """Get original tool name without prefix for calling tools.""" return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name ``` #### is_running ```python is_running: bool ``` Check if the MCP server is running. #### list_tools ```python list_tools() -> list[ToolDefinition] ``` Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ```python async def list_tools(self) -> list[tools.ToolDefinition]: """Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ mcp_tools = await self._client.list_tools() return [ tools.ToolDefinition( name=self.get_prefixed_tool_name(tool.name), description=tool.description or '', parameters_json_schema=tool.inputSchema, ) for tool in mcp_tools.tools ] ``` #### call_tool ```python call_tool( tool_name: str, arguments: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> ToolResult ``` Call a tool on the server. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `tool_name` | `str` | The name of the tool to call. | *required* | | `arguments` | `dict[str, Any]` | The arguments to pass to the tool. | *required* | | `metadata` | `dict[str, Any] | None` | Request-level metadata (optional) | `None` | Returns: | Type | Description | | --- | --- | | `ToolResult` | The result of the tool call. | Raises: | Type | Description | | --- | --- | | `ModelRetry` | If the tool call fails. 
| Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ```python async def call_tool( self, tool_name: str, arguments: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> ToolResult: """Call a tool on the server. Args: tool_name: The name of the tool to call. arguments: The arguments to pass to the tool. metadata: Request-level metadata (optional) Returns: The result of the tool call. Raises: ModelRetry: If the tool call fails. """ try: # meta param is not provided by session yet, so build and can send_request directly. result = await self._client.send_request( mcp_types.ClientRequest( mcp_types.CallToolRequest( method='tools/call', params=mcp_types.CallToolRequestParams( name=self.get_unprefixed_tool_name(tool_name), arguments=arguments, _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, ), ) ), mcp_types.CallToolResult, ) except McpError as e: raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] if result.isError: text = '\n'.join(str(part) for part in content) raise exceptions.ModelRetry(text) else: return content[0] if len(content) == 1 else content ``` ### MCPServerStdio Bases: `MCPServer` Runs an MCP server in a subprocess and communicates with it over stdin/stdout. This class implements the stdio transport from the MCP specification. See for more information. Note Using this class as an async context manager will start the server as a subprocess when entering the context, and stop it when exiting the context. Example: ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio( # (1)! 'deno', args=[ 'run', '-N', '-R=node_modules', '-W=node_modules', '--node-modules-dir=auto', 'jsr:@pydantic/mcp-run-python', 'stdio', ] ) agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` 1. See [MCP Run Python](../../mcp/run-python/) for more information. 1. This will start the server as a subprocess and connect to it. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ````python @dataclass class MCPServerStdio(MCPServer): """Runs an MCP server in a subprocess and communicates with it over stdin/stdout. This class implements the stdio transport from the MCP specification. See for more information. !!! note Using this class as an async context manager will start the server as a subprocess when entering the context, and stop it when exiting the context. Example: ```python {py="3.10"} from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio( # (1)! 'deno', args=[ 'run', '-N', '-R=node_modules', '-W=node_modules', '--node-modules-dir=auto', 'jsr:@pydantic/mcp-run-python', 'stdio', ] ) agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` 1. See [MCP Run Python](../mcp/run-python.md) for more information. 2. This will start the server as a subprocess and connect to it. """ command: str """The command to run.""" args: Sequence[str] """The arguments to pass to the command.""" env: dict[str, str] | None = None """The environment variables the CLI server will have access to. By default the subprocess will not inherit any environment variables from the parent process. If you want to inherit the environment variables from the parent process, use `env=os.environ`. 
""" cwd: str | Path | None = None """The working directory to use when spawning the process.""" # last fields are re-defined from the parent class so they appear as fields tool_prefix: str | None = None """A prefix to add to all tools that are registered with the server. If not empty, will include a trailing underscore(`_`). e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar` """ log_level: mcp_types.LoggingLevel | None = None """The log level to set when connecting to the server, if any. See for more details. If `None`, no log level will be set. """ log_handler: LoggingFnT | None = None """A handler for logging messages from the server.""" timeout: float = 5 """The timeout in seconds to wait for the client to initialize.""" process_tool_call: ProcessToolCallback | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" @asynccontextmanager async def client_streams( self, ) -> AsyncIterator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], ] ]: server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env, cwd=self.cwd) async with stdio_client(server=server) as (read_stream, write_stream): yield read_stream, write_stream def __repr__(self) -> str: return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})' ```` #### command ```python command: str ``` The command to run. #### args ```python args: Sequence[str] ``` The arguments to pass to the command. #### env ```python env: dict[str, str] | None = None ``` The environment variables the CLI server will have access to. By default the subprocess will not inherit any environment variables from the parent process. If you want to inherit the environment variables from the parent process, use `env=os.environ`. #### cwd ```python cwd: str | Path | None = None ``` The working directory to use when spawning the process. #### tool_prefix ```python tool_prefix: str | None = None ``` A prefix to add to all tools that are registered with the server. If not empty, will include a trailing underscore(`_`). e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar` #### log_level ```python log_level: LoggingLevel | None = None ``` The log level to set when connecting to the server, if any. See for more details. If `None`, no log level will be set. #### log_handler ```python log_handler: LoggingFnT | None = None ``` A handler for logging messages from the server. #### timeout ```python timeout: float = 5 ``` The timeout in seconds to wait for the client to initialize. #### process_tool_call ```python process_tool_call: ProcessToolCallback | None = None ``` Hook to customize tool calling and optionally pass extra metadata. #### allow_sampling ```python allow_sampling: bool = True ``` Whether to allow MCP sampling through this client. ### MCPServerSSE Bases: `_MCPServerHTTP` An MCP server that connects over streamable HTTP connections. This class implements the SSE transport from the MCP specification. See for more information. Note Using this class as an async context manager will create a new pool of HTTP connections to connect to a server which should already be running. Example: ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE('http://localhost:3001/sse') # (1)! 
agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../../mcp/run-python/). 1. This will connect to a server running on `localhost:3001`. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ````python @dataclass class MCPServerSSE(_MCPServerHTTP): """An MCP server that connects over streamable HTTP connections. This class implements the SSE transport from the MCP specification. See for more information. !!! note Using this class as an async context manager will create a new pool of HTTP connections to connect to a server which should already be running. Example: ```python {py="3.10"} from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE('http://localhost:3001/sse') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md). 2. This will connect to a server running on `localhost:3001`. """ @property def _transport_client(self): return sse_client # pragma: no cover ```` ### MCPServerHTTP Bases: `MCPServerSSE` An MCP server that connects over HTTP using the old SSE transport. This class implements the SSE transport from the MCP specification. See for more information. Note Using this class as an async context manager will create a new pool of HTTP connections to connect to a server which should already be running. Example: ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerHTTP server = MCPServerHTTP('http://localhost:3001/sse') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../../mcp/run-python/). 1. This will connect to a server running on `localhost:3001`. Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ````python @deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.') @dataclass class MCPServerHTTP(MCPServerSSE): """An MCP server that connects over HTTP using the old SSE transport. This class implements the SSE transport from the MCP specification. See for more information. !!! note Using this class as an async context manager will create a new pool of HTTP connections to connect to a server which should already be running. Example: ```python {py="3.10" test="skip"} from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerHTTP server = MCPServerHTTP('http://localhost:3001/sse') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` 1. E.g. you might be connecting to a server run with [`mcp-run-python`](../mcp/run-python.md). 2. This will connect to a server running on `localhost:3001`. """ ```` ### MCPServerStreamableHTTP Bases: `_MCPServerHTTP` An MCP server that connects over HTTP using the Streamable HTTP transport. This class implements the Streamable HTTP transport from the MCP specification. See for more information. Note Using this class as an async context manager will create a new pool of HTTP connections to connect to a server which should already be running. Example: ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! 
agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` Source code in `pydantic_ai_slim/pydantic_ai/mcp.py` ````python @dataclass class MCPServerStreamableHTTP(_MCPServerHTTP): """An MCP server that connects over HTTP using the Streamable HTTP transport. This class implements the Streamable HTTP transport from the MCP specification. See for more information. !!! note Using this class as an async context manager will create a new pool of HTTP connections to connect to a server which should already be running. Example: ```python {py="3.10"} from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): # (2)! ... ``` """ @property def _transport_client(self): return streamablehttp_client # pragma: no cover ```` # `pydantic_ai.messages` The structure of ModelMessage can be shown as a graph: ``` graph RL SystemPromptPart(SystemPromptPart) --- ModelRequestPart UserPromptPart(UserPromptPart) --- ModelRequestPart ToolReturnPart(ToolReturnPart) --- ModelRequestPart RetryPromptPart(RetryPromptPart) --- ModelRequestPart TextPart(TextPart) --- ModelResponsePart ToolCallPart(ToolCallPart) --- ModelResponsePart ModelRequestPart("ModelRequestPart
(Union)") --- ModelRequest ModelRequest("ModelRequest(parts=list[...])") --- ModelMessage ModelResponsePart("ModelResponsePart
(Union)") --- ModelResponse ModelResponse("ModelResponse(parts=list[...])") --- ModelMessage("ModelMessage
(Union)") ``` ### SystemPromptPart A system prompt, generally written by the application developer. This gives the model context and guidance on how to respond. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class SystemPromptPart: """A system prompt, generally written by the application developer. This gives the model context and guidance on how to respond. """ content: str """The content of the prompt.""" timestamp: datetime = field(default_factory=_now_utc) """The timestamp of the prompt.""" dynamic_ref: str | None = None """The ref of the dynamic system prompt function that generated this part. Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information. """ part_kind: Literal['system-prompt'] = 'system-prompt' """Part type identifier, this is available on all parts as a discriminator.""" def otel_event(self, settings: InstrumentationSettings) -> Event: return Event( 'gen_ai.system.message', body={'role': 'system', **({'content': self.content} if settings.include_content else {})}, ) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content ```python content: str ``` The content of the prompt. #### timestamp ```python timestamp: datetime = field(default_factory=now_utc) ``` The timestamp of the prompt. #### dynamic_ref ```python dynamic_ref: str | None = None ``` The ref of the dynamic system prompt function that generated this part. Only set if system prompt is dynamic, see system_prompt for more information. #### part_kind ```python part_kind: Literal['system-prompt'] = 'system-prompt' ``` Part type identifier, this is available on all parts as a discriminator. ### FileUrl Bases: `ABC` Abstract base class for any URL-based file. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class FileUrl(ABC): """Abstract base class for any URL-based file.""" url: str """The URL of the file.""" force_download: bool = False """If the model supports it: * If True, the file is downloaded and the data is sent to the model as bytes. * If False, the URL is sent directly to the model and no download is performed. """ vendor_metadata: dict[str, Any] | None = None """Vendor-specific metadata for the file. Supported by: - `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing """ @property @abstractmethod def media_type(self) -> str: """Return the media type of the file, based on the url.""" @property @abstractmethod def format(self) -> str: """The file format.""" __repr__ = _utils.dataclasses_no_defaults_repr ``` #### url ```python url: str ``` The URL of the file. #### force_download ```python force_download: bool = False ``` If the model supports it: - If True, the file is downloaded and the data is sent to the model as bytes. - If False, the URL is sent directly to the model and no download is performed. #### vendor_metadata ```python vendor_metadata: dict[str, Any] | None = None ``` Vendor-specific metadata for the file. Supported by: - `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing #### media_type ```python media_type: str ``` Return the media type of the file, based on the url. #### format ```python format: str ``` The file format. ### VideoUrl Bases: `FileUrl` A URL to a video. 
Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class VideoUrl(FileUrl): """A URL to a video.""" url: str """The URL of the video.""" kind: Literal['video-url'] = 'video-url' """Type identifier, this is available on all parts as a discriminator.""" @property def media_type(self) -> VideoMediaType: """Return the media type of the video, based on the url.""" if self.url.endswith('.mkv'): return 'video/x-matroska' elif self.url.endswith('.mov'): return 'video/quicktime' elif self.url.endswith('.mp4'): return 'video/mp4' elif self.url.endswith('.webm'): return 'video/webm' elif self.url.endswith('.flv'): return 'video/x-flv' elif self.url.endswith(('.mpeg', '.mpg')): return 'video/mpeg' elif self.url.endswith('.wmv'): return 'video/x-ms-wmv' elif self.url.endswith('.three_gp'): return 'video/3gpp' # Assume that YouTube videos are mp4 because there would be no extension # to infer from. This should not be a problem, as Gemini disregards media # type for YouTube URLs. elif self.is_youtube: return 'video/mp4' else: raise ValueError(f'Unknown video file extension: {self.url}') @property def is_youtube(self) -> bool: """True if the URL has a YouTube domain.""" return self.url.startswith(('https://youtu.be/', 'https://youtube.com/', 'https://www.youtube.com/')) @property def format(self) -> VideoFormat: """The file format of the video. The choice of supported formats were based on the Bedrock Converse API. Other APIs don't require to use a format. """ return _video_format_lookup[self.media_type] ``` #### url ```python url: str ``` The URL of the video. #### kind ```python kind: Literal['video-url'] = 'video-url' ``` Type identifier, this is available on all parts as a discriminator. #### media_type ```python media_type: VideoMediaType ``` Return the media type of the video, based on the url. #### is_youtube ```python is_youtube: bool ``` True if the URL has a YouTube domain. #### format ```python format: VideoFormat ``` The file format of the video. The choice of supported formats were based on the Bedrock Converse API. Other APIs don't require to use a format. ### AudioUrl Bases: `FileUrl` A URL to an audio file. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class AudioUrl(FileUrl): """A URL to an audio file.""" url: str """The URL of the audio file.""" kind: Literal['audio-url'] = 'audio-url' """Type identifier, this is available on all parts as a discriminator.""" @property def media_type(self) -> AudioMediaType: """Return the media type of the audio file, based on the url.""" if self.url.endswith('.mp3'): return 'audio/mpeg' elif self.url.endswith('.wav'): return 'audio/wav' else: raise ValueError(f'Unknown audio file extension: {self.url}') @property def format(self) -> AudioFormat: """The file format of the audio file.""" return _audio_format_lookup[self.media_type] ``` #### url ```python url: str ``` The URL of the audio file. #### kind ```python kind: Literal['audio-url'] = 'audio-url' ``` Type identifier, this is available on all parts as a discriminator. #### media_type ```python media_type: AudioMediaType ``` Return the media type of the audio file, based on the url. #### format ```python format: AudioFormat ``` The file format of the audio file. ### ImageUrl Bases: `FileUrl` A URL to an image. 
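File URL parts like this are typically passed alongside text in a user prompt. A minimal sketch; the model name and image URL are illustrative, and it assumes the chosen model supports image input:

```python
from pydantic_ai import Agent, ImageUrl

agent = Agent('openai:gpt-4o')
result = agent.run_sync(
    [
        'What company is this logo from?',
        ImageUrl(url='https://example.com/logo.png'),  # illustrative URL
    ]
)
print(result.output)
```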
Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ImageUrl(FileUrl): """A URL to an image.""" url: str """The URL of the image.""" kind: Literal['image-url'] = 'image-url' """Type identifier, this is available on all parts as a discriminator.""" @property def media_type(self) -> ImageMediaType: """Return the media type of the image, based on the url.""" if self.url.endswith(('.jpg', '.jpeg')): return 'image/jpeg' elif self.url.endswith('.png'): return 'image/png' elif self.url.endswith('.gif'): return 'image/gif' elif self.url.endswith('.webp'): return 'image/webp' else: raise ValueError(f'Unknown image file extension: {self.url}') @property def format(self) -> ImageFormat: """The file format of the image. The choice of supported formats were based on the Bedrock Converse API. Other APIs don't require to use a format. """ return _image_format_lookup[self.media_type] ``` #### url ```python url: str ``` The URL of the image. #### kind ```python kind: Literal['image-url'] = 'image-url' ``` Type identifier, this is available on all parts as a discriminator. #### media_type ```python media_type: ImageMediaType ``` Return the media type of the image, based on the url. #### format ```python format: ImageFormat ``` The file format of the image. The choice of supported formats were based on the Bedrock Converse API. Other APIs don't require to use a format. ### DocumentUrl Bases: `FileUrl` The URL of the document. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class DocumentUrl(FileUrl): """The URL of the document.""" url: str """The URL of the document.""" kind: Literal['document-url'] = 'document-url' """Type identifier, this is available on all parts as a discriminator.""" @property def media_type(self) -> str: """Return the media type of the document, based on the url.""" type_, _ = guess_type(self.url) if type_ is None: raise ValueError(f'Unknown document file extension: {self.url}') return type_ @property def format(self) -> DocumentFormat: """The file format of the document. The choice of supported formats were based on the Bedrock Converse API. Other APIs don't require to use a format. """ media_type = self.media_type try: return _document_format_lookup[media_type] except KeyError as e: raise ValueError(f'Unknown document media type: {media_type}') from e ``` #### url ```python url: str ``` The URL of the document. #### kind ```python kind: Literal['document-url'] = 'document-url' ``` Type identifier, this is available on all parts as a discriminator. #### media_type ```python media_type: str ``` Return the media type of the document, based on the url. #### format ```python format: DocumentFormat ``` The file format of the document. The choice of supported formats were based on the Bedrock Converse API. Other APIs don't require to use a format. ### BinaryContent Binary content, e.g. an audio or image file. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class BinaryContent: """Binary content, e.g. an audio or image file.""" data: bytes """The binary data.""" media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str """The media type of the binary data.""" vendor_metadata: dict[str, Any] | None = None """Vendor-specific metadata for the file. 
Supported by: - `GoogleModel`: `BinaryContent.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing """ kind: Literal['binary'] = 'binary' """Type identifier, this is available on all parts as a discriminator.""" @property def is_audio(self) -> bool: """Return `True` if the media type is an audio type.""" return self.media_type.startswith('audio/') @property def is_image(self) -> bool: """Return `True` if the media type is an image type.""" return self.media_type.startswith('image/') @property def is_video(self) -> bool: """Return `True` if the media type is a video type.""" return self.media_type.startswith('video/') @property def is_document(self) -> bool: """Return `True` if the media type is a document type.""" return self.media_type in _document_format_lookup @property def format(self) -> str: """The file format of the binary content.""" try: if self.is_audio: return _audio_format_lookup[self.media_type] elif self.is_image: return _image_format_lookup[self.media_type] elif self.is_video: return _video_format_lookup[self.media_type] else: return _document_format_lookup[self.media_type] except KeyError as e: raise ValueError(f'Unknown media type: {self.media_type}') from e __repr__ = _utils.dataclasses_no_defaults_repr ``` #### data ```python data: bytes ``` The binary data. #### media_type ```python media_type: ( AudioMediaType | ImageMediaType | DocumentMediaType | str ) ``` The media type of the binary data. #### vendor_metadata ```python vendor_metadata: dict[str, Any] | None = None ``` Vendor-specific metadata for the file. Supported by: - `GoogleModel`: `BinaryContent.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing #### kind ```python kind: Literal['binary'] = 'binary' ``` Type identifier, this is available on all parts as a discriminator. #### is_audio ```python is_audio: bool ``` Return `True` if the media type is an audio type. #### is_image ```python is_image: bool ``` Return `True` if the media type is an image type. #### is_video ```python is_video: bool ``` Return `True` if the media type is a video type. #### is_document ```python is_document: bool ``` Return `True` if the media type is a document type. #### format ```python format: str ``` The file format of the binary content. ### ToolReturn A structured return value for tools that need to provide both a return value and custom content to the model. This class allows tools to return complex responses that include: - A return value for actual tool return - Custom content (including multi-modal content) to be sent to the model as a UserPromptPart - Optional metadata for application use Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ToolReturn: """A structured return value for tools that need to provide both a return value and custom content to the model. 
This class allows tools to return complex responses that include: - A return value for actual tool return - Custom content (including multi-modal content) to be sent to the model as a UserPromptPart - Optional metadata for application use """ return_value: Any """The return value to be used in the tool response.""" content: Sequence[UserContent] | None = None """The content sequence to be sent to the model as a UserPromptPart.""" metadata: Any = None """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" __repr__ = _utils.dataclasses_no_defaults_repr ``` #### return_value ```python return_value: Any ``` The return value to be used in the tool response. #### content ```python content: Sequence[UserContent] | None = None ``` The content sequence to be sent to the model as a UserPromptPart. #### metadata ```python metadata: Any = None ``` Additional data that can be accessed programmatically by the application but is not sent to the LLM. ### UserPromptPart A user prompt, generally written by the end user. Content comes from the `user_prompt` parameter of Agent.run, Agent.run_sync, and Agent.run_stream. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class UserPromptPart: """A user prompt, generally written by the end user. Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream]. """ content: str | Sequence[UserContent] """The content of the prompt.""" timestamp: datetime = field(default_factory=_now_utc) """The timestamp of the prompt.""" part_kind: Literal['user-prompt'] = 'user-prompt' """Part type identifier, this is available on all parts as a discriminator.""" def otel_event(self, settings: InstrumentationSettings) -> Event: content: str | list[dict[str, Any] | str] if isinstance(self.content, str): content = self.content else: content = [] for part in self.content: if isinstance(part, str): content.append(part if settings.include_content else {'kind': 'text'}) elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)): content.append({'kind': part.kind, **({'url': part.url} if settings.include_content else {})}) elif isinstance(part, BinaryContent): converted_part = {'kind': part.kind, 'media_type': part.media_type} if settings.include_content and settings.include_binary_content: converted_part['binary_content'] = base64.b64encode(part.data).decode() content.append(converted_part) else: content.append({'kind': part.kind}) # pragma: no cover return Event('gen_ai.user.message', body={'content': content, 'role': 'user'}) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content ```python content: str | Sequence[UserContent] ``` The content of the prompt. #### timestamp ```python timestamp: datetime = field(default_factory=now_utc) ``` The timestamp of the prompt. #### part_kind ```python part_kind: Literal['user-prompt'] = 'user-prompt' ``` Part type identifier, this is available on all parts as a discriminator. ### ToolReturnPart A tool return message, this encodes the result of running a tool. 
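For illustration, a minimal sketch of constructing a `ToolReturnPart` and rendering its content for the model (the tool name, ID, and content below are invented):

```python
from pydantic_ai.messages import ToolReturnPart

part = ToolReturnPart(
    tool_name='roulette_wheel',  # hypothetical tool name
    content={'winner': True},
    tool_call_id='call_abc123',  # hypothetical tool call ID
)
print(part.model_response_object())
#> {'winner': True}
```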
Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ToolReturnPart: """A tool return message, this encodes the result of running a tool.""" tool_name: str """The name of the "tool" was called.""" content: Any """The return value.""" tool_call_id: str """The tool call identifier, this is used by some models including OpenAI.""" metadata: Any = None """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the tool returned.""" part_kind: Literal['tool-return'] = 'tool-return' """Part type identifier, this is available on all parts as a discriminator.""" def model_response_str(self) -> str: """Return a string representation of the content for the model.""" if isinstance(self.content, str): return self.content else: return tool_return_ta.dump_json(self.content).decode() def model_response_object(self) -> dict[str, Any]: """Return a dictionary representation of the content, wrapping non-dict types appropriately.""" # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict if isinstance(self.content, dict): return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType] else: return {'return_value': tool_return_ta.dump_python(self.content, mode='json')} def otel_event(self, settings: InstrumentationSettings) -> Event: return Event( 'gen_ai.tool.message', body={ **({'content': self.content} if settings.include_content else {}), 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name, }, ) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### tool_name ```python tool_name: str ``` The name of the "tool" was called. #### content ```python content: Any ``` The return value. #### tool_call_id ```python tool_call_id: str ``` The tool call identifier, this is used by some models including OpenAI. #### metadata ```python metadata: Any = None ``` Additional data that can be accessed programmatically by the application but is not sent to the LLM. #### timestamp ```python timestamp: datetime = field(default_factory=now_utc) ``` The timestamp, when the tool returned. #### part_kind ```python part_kind: Literal['tool-return'] = 'tool-return' ``` Part type identifier, this is available on all parts as a discriminator. #### model_response_str ```python model_response_str() -> str ``` Return a string representation of the content for the model. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def model_response_str(self) -> str: """Return a string representation of the content for the model.""" if isinstance(self.content, str): return self.content else: return tool_return_ta.dump_json(self.content).decode() ``` #### model_response_object ```python model_response_object() -> dict[str, Any] ``` Return a dictionary representation of the content, wrapping non-dict types appropriately. 
Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def model_response_object(self) -> dict[str, Any]: """Return a dictionary representation of the content, wrapping non-dict types appropriately.""" # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict if isinstance(self.content, dict): return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType] else: return {'return_value': tool_return_ta.dump_python(self.content, mode='json')} ``` ### RetryPromptPart A message back to a model asking it to try again. This can be sent for a number of reasons: - Pydantic validation of tool arguments failed, here content is derived from a Pydantic ValidationError - a tool raised a ModelRetry exception - no tool was found for the tool name - the model returned plain text when a structured response was expected - Pydantic validation of a structured response failed, here content is derived from a Pydantic ValidationError - an output validator raised a ModelRetry exception Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class RetryPromptPart: """A message back to a model asking it to try again. This can be sent for a number of reasons: * Pydantic validation of tool arguments failed, here content is derived from a Pydantic [`ValidationError`][pydantic_core.ValidationError] * a tool raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception * no tool was found for the tool name * the model returned plain text when a structured response was expected * Pydantic validation of a structured response failed, here content is derived from a Pydantic [`ValidationError`][pydantic_core.ValidationError] * an output validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception """ content: list[pydantic_core.ErrorDetails] | str """Details of why and how the model should retry. If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of error details. """ tool_name: str | None = None """The name of the tool that was called, if any.""" tool_call_id: str = field(default_factory=_generate_tool_call_id) """The tool call identifier, this is used by some models including OpenAI. In case the tool call id is not provided by the model, PydanticAI will generate a random one. """ timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the retry was triggered.""" part_kind: Literal['retry-prompt'] = 'retry-prompt' """Part type identifier, this is available on all parts as a discriminator.""" def model_response(self) -> str: """Return a string message describing why the retry is requested.""" if isinstance(self.content, str): if self.tool_name is None: description = f'Validation feedback:\n{self.content}' else: description = self.content else: json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2) description = f'{len(self.content)} validation errors: {json_errors.decode()}' return f'{description}\n\nFix the errors and try again.' 
def otel_event(self, settings: InstrumentationSettings) -> Event: if self.tool_name is None: return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'}) else: return Event( 'gen_ai.tool.message', body={ **({'content': self.model_response()} if settings.include_content else {}), 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name, }, ) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content ```python content: list[ErrorDetails] | str ``` Details of why and how the model should retry. If the retry was triggered by a ValidationError, this will be a list of error details. #### tool_name ```python tool_name: str | None = None ``` The name of the tool that was called, if any. #### tool_call_id ```python tool_call_id: str = field( default_factory=generate_tool_call_id ) ``` The tool call identifier, this is used by some models including OpenAI. In case the tool call id is not provided by the model, PydanticAI will generate a random one. #### timestamp ```python timestamp: datetime = field(default_factory=now_utc) ``` The timestamp, when the retry was triggered. #### part_kind ```python part_kind: Literal['retry-prompt'] = 'retry-prompt' ``` Part type identifier, this is available on all parts as a discriminator. #### model_response ```python model_response() -> str ``` Return a string message describing why the retry is requested. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def model_response(self) -> str: """Return a string message describing why the retry is requested.""" if isinstance(self.content, str): if self.tool_name is None: description = f'Validation feedback:\n{self.content}' else: description = self.content else: json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2) description = f'{len(self.content)} validation errors: {json_errors.decode()}' return f'{description}\n\nFix the errors and try again.' ``` ### ModelRequestPart ```python ModelRequestPart = Annotated[ Union[ SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart, ], Discriminator("part_kind"), ] ``` A message part sent by PydanticAI to a model. ### ModelRequest A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ModelRequest: """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model.""" parts: list[ModelRequestPart] """The parts of the user message.""" instructions: str | None = None """The instructions for the model.""" kind: Literal['request'] = 'request' """Message type identifier, this is available on all parts as a discriminator.""" @classmethod def user_text_prompt(cls, user_prompt: str, *, instructions: str | None = None) -> ModelRequest: """Create a `ModelRequest` with a single user prompt as text.""" return cls(parts=[UserPromptPart(user_prompt)], instructions=instructions) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### parts ```python parts: list[ModelRequestPart] ``` The parts of the user message. #### instructions ```python instructions: str | None = None ``` The instructions for the model. #### kind ```python kind: Literal['request'] = 'request' ``` Message type identifier, this is available on all parts as a discriminator. 
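As a brief sketch (the prompt text is invented), a request can be assembled directly from parts; the `user_text_prompt` classmethod below covers the common single-prompt case:

```python
from pydantic_ai.messages import ModelRequest, SystemPromptPart, UserPromptPart

request = ModelRequest(
    parts=[
        SystemPromptPart(content='Be concise.'),
        UserPromptPart(content='What is the capital of France?'),
    ]
)
print(request.kind)
#> request
```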
#### user_text_prompt ```python user_text_prompt( user_prompt: str, *, instructions: str | None = None ) -> ModelRequest ``` Create a `ModelRequest` with a single user prompt as text. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @classmethod def user_text_prompt(cls, user_prompt: str, *, instructions: str | None = None) -> ModelRequest: """Create a `ModelRequest` with a single user prompt as text.""" return cls(parts=[UserPromptPart(user_prompt)], instructions=instructions) ``` ### TextPart A plain text response from a model. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class TextPart: """A plain text response from a model.""" content: str """The text content of the response.""" part_kind: Literal['text'] = 'text' """Part type identifier, this is available on all parts as a discriminator.""" def has_content(self) -> bool: """Return `True` if the text content is non-empty.""" return bool(self.content) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content ```python content: str ``` The text content of the response. #### part_kind ```python part_kind: Literal['text'] = 'text' ``` Part type identifier, this is available on all parts as a discriminator. #### has_content ```python has_content() -> bool ``` Return `True` if the text content is non-empty. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def has_content(self) -> bool: """Return `True` if the text content is non-empty.""" return bool(self.content) ``` ### ThinkingPart A thinking response from a model. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ThinkingPart: """A thinking response from a model.""" content: str """The thinking content of the response.""" id: str | None = None """The identifier of the thinking part.""" signature: str | None = None """The signature of the thinking. The signature is only available on the Anthropic models. """ part_kind: Literal['thinking'] = 'thinking' """Part type identifier, this is available on all parts as a discriminator.""" def has_content(self) -> bool: """Return `True` if the thinking content is non-empty.""" return bool(self.content) # pragma: no cover __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content ```python content: str ``` The thinking content of the response. #### id ```python id: str | None = None ``` The identifier of the thinking part. #### signature ```python signature: str | None = None ``` The signature of the thinking. The signature is only available on the Anthropic models. #### part_kind ```python part_kind: Literal['thinking'] = 'thinking' ``` Part type identifier, this is available on all parts as a discriminator. #### has_content ```python has_content() -> bool ``` Return `True` if the thinking content is non-empty. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def has_content(self) -> bool: """Return `True` if the thinking content is non-empty.""" return bool(self.content) # pragma: no cover ``` ### ToolCallPart A tool call from a model. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ToolCallPart: """A tool call from a model.""" tool_name: str """The name of the tool to call.""" args: str | dict[str, Any] | None = None """The arguments to pass to the tool. This is stored either as a JSON string or a Python dictionary depending on how data was received. 
""" tool_call_id: str = field(default_factory=_generate_tool_call_id) """The tool call identifier, this is used by some models including OpenAI. In case the tool call id is not provided by the model, PydanticAI will generate a random one. """ part_kind: Literal['tool-call'] = 'tool-call' """Part type identifier, this is available on all parts as a discriminator.""" def args_as_dict(self) -> dict[str, Any]: """Return the arguments as a Python dictionary. This is just for convenience with models that require dicts as input. """ if not self.args: return {} if isinstance(self.args, dict): return self.args args = pydantic_core.from_json(self.args) assert isinstance(args, dict), 'args should be a dict' return cast(dict[str, Any], args) def args_as_json_str(self) -> str: """Return the arguments as a JSON string. This is just for convenience with models that require JSON strings as input. """ if not self.args: return '{}' if isinstance(self.args, str): return self.args return pydantic_core.to_json(self.args).decode() def has_content(self) -> bool: """Return `True` if the arguments contain any data.""" if isinstance(self.args, dict): # TODO: This should probably return True if you have the value False, or 0, etc. # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts return any(self.args.values()) else: return bool(self.args) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### tool_name ```python tool_name: str ``` The name of the tool to call. #### args ```python args: str | dict[str, Any] | None = None ``` The arguments to pass to the tool. This is stored either as a JSON string or a Python dictionary depending on how data was received. #### tool_call_id ```python tool_call_id: str = field( default_factory=generate_tool_call_id ) ``` The tool call identifier, this is used by some models including OpenAI. In case the tool call id is not provided by the model, PydanticAI will generate a random one. #### part_kind ```python part_kind: Literal['tool-call'] = 'tool-call' ``` Part type identifier, this is available on all parts as a discriminator. #### args_as_dict ```python args_as_dict() -> dict[str, Any] ``` Return the arguments as a Python dictionary. This is just for convenience with models that require dicts as input. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def args_as_dict(self) -> dict[str, Any]: """Return the arguments as a Python dictionary. This is just for convenience with models that require dicts as input. """ if not self.args: return {} if isinstance(self.args, dict): return self.args args = pydantic_core.from_json(self.args) assert isinstance(args, dict), 'args should be a dict' return cast(dict[str, Any], args) ``` #### args_as_json_str ```python args_as_json_str() -> str ``` Return the arguments as a JSON string. This is just for convenience with models that require JSON strings as input. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def args_as_json_str(self) -> str: """Return the arguments as a JSON string. This is just for convenience with models that require JSON strings as input. """ if not self.args: return '{}' if isinstance(self.args, str): return self.args return pydantic_core.to_json(self.args).decode() ``` #### has_content ```python has_content() -> bool ``` Return `True` if the arguments contain any data. 
Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def has_content(self) -> bool: """Return `True` if the arguments contain any data.""" if isinstance(self.args, dict): # TODO: This should probably return True if you have the value False, or 0, etc. # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts return any(self.args.values()) else: return bool(self.args) ``` ### ModelResponsePart ```python ModelResponsePart = Annotated[ Union[TextPart, ToolCallPart, ThinkingPart], Discriminator("part_kind"), ] ``` A message part returned by a model. ### ModelResponse A response from a model, e.g. a message from the model to the PydanticAI app. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ModelResponse: """A response from a model, e.g. a message from the model to the PydanticAI app.""" parts: list[ModelResponsePart] """The parts of the model message.""" usage: Usage = field(default_factory=Usage) """Usage information for the request. This has a default to make tests easier, and to support loading old messages where usage will be missing. """ model_name: str | None = None """The name of the model that generated the response.""" timestamp: datetime = field(default_factory=_now_utc) """The timestamp of the response. If the model provides a timestamp in the response (as OpenAI does) that will be used. """ kind: Literal['response'] = 'response' """Message type identifier, this is available on all parts as a discriminator.""" vendor_details: dict[str, Any] | None = field(default=None) """Additional vendor-specific details in a serializable format. This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields. For OpenAI models, this may include 'logprobs', 'finish_reason', etc. """ vendor_id: str | None = None """Vendor ID as specified by the model provider. This can be used to track the specific request to the model.""" def otel_events(self, settings: InstrumentationSettings) -> list[Event]: """Return OpenTelemetry events for the response.""" result: list[Event] = [] def new_event_body(): new_body: dict[str, Any] = {'role': 'assistant'} ev = Event('gen_ai.assistant.message', body=new_body) result.append(ev) return new_body body = new_event_body() for part in self.parts: if isinstance(part, ToolCallPart): body.setdefault('tool_calls', []).append( { 'id': part.tool_call_id, 'type': 'function', # TODO https://github.com/pydantic/pydantic-ai/issues/888 'function': { 'name': part.tool_name, 'arguments': part.args, }, } ) elif isinstance(part, TextPart): if body.get('content'): body = new_event_body() if settings.include_content: body['content'] = part.content return result __repr__ = _utils.dataclasses_no_defaults_repr ``` #### parts ```python parts: list[ModelResponsePart] ``` The parts of the model message. #### usage ```python usage: Usage = field(default_factory=Usage) ``` Usage information for the request. This has a default to make tests easier, and to support loading old messages where usage will be missing. #### model_name ```python model_name: str | None = None ``` The name of the model that generated the response. #### timestamp ```python timestamp: datetime = field(default_factory=now_utc) ``` The timestamp of the response. If the model provides a timestamp in the response (as OpenAI does) that will be used. #### kind ```python kind: Literal['response'] = 'response' ``` Message type identifier, this is available on all parts as a discriminator. 
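As a brief sketch (part contents invented), a response is just a list of parts plus metadata:

```python
from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart

response = ModelResponse(
    parts=[
        TextPart(content='Checking the weather...'),
        ToolCallPart(tool_name='get_weather', args={'city': 'Paris'}),  # hypothetical tool call
    ]
)
print(response.kind)
#> response
print([p.part_kind for p in response.parts])
#> ['text', 'tool-call']
```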
#### vendor_details ```python vendor_details: dict[str, Any] | None = field(default=None) ``` Additional vendor-specific details in a serializable format. This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields. For OpenAI models, this may include 'logprobs', 'finish_reason', etc. #### vendor_id ```python vendor_id: str | None = None ``` Vendor ID as specified by the model provider. This can be used to track the specific request to the model. #### otel_events ```python otel_events( settings: InstrumentationSettings, ) -> list[Event] ``` Return OpenTelemetry events for the response. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def otel_events(self, settings: InstrumentationSettings) -> list[Event]: """Return OpenTelemetry events for the response.""" result: list[Event] = [] def new_event_body(): new_body: dict[str, Any] = {'role': 'assistant'} ev = Event('gen_ai.assistant.message', body=new_body) result.append(ev) return new_body body = new_event_body() for part in self.parts: if isinstance(part, ToolCallPart): body.setdefault('tool_calls', []).append( { 'id': part.tool_call_id, 'type': 'function', # TODO https://github.com/pydantic/pydantic-ai/issues/888 'function': { 'name': part.tool_name, 'arguments': part.args, }, } ) elif isinstance(part, TextPart): if body.get('content'): body = new_event_body() if settings.include_content: body['content'] = part.content return result ``` ### ModelMessage ```python ModelMessage = Annotated[ Union[ModelRequest, ModelResponse], Discriminator("kind"), ] ``` Any message sent to or returned by a model. ### ModelMessagesTypeAdapter ```python ModelMessagesTypeAdapter = TypeAdapter( list[ModelMessage], config=ConfigDict( defer_build=True, ser_json_bytes="base64", val_json_bytes="base64", ), ) ``` Pydantic TypeAdapter for (de)serializing messages. ### TextPartDelta A partial update (delta) for a `TextPart` to append new text content. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class TextPartDelta: """A partial update (delta) for a `TextPart` to append new text content.""" content_delta: str """The incremental text content to add to the existing `TextPart` content.""" part_delta_kind: Literal['text'] = 'text' """Part delta type identifier, used as a discriminator.""" def apply(self, part: ModelResponsePart) -> TextPart: """Apply this text delta to an existing `TextPart`. Args: part: The existing model response part, which must be a `TextPart`. Returns: A new `TextPart` with updated text content. Raises: ValueError: If `part` is not a `TextPart`. """ if not isinstance(part, TextPart): raise ValueError('Cannot apply TextPartDeltas to non-TextParts') # pragma: no cover return replace(part, content=part.content + self.content_delta) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content_delta ```python content_delta: str ``` The incremental text content to add to the existing `TextPart` content. #### part_delta_kind ```python part_delta_kind: Literal['text'] = 'text' ``` Part delta type identifier, used as a discriminator. #### apply ```python apply(part: ModelResponsePart) -> TextPart ``` Apply this text delta to an existing `TextPart`. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `part` | `ModelResponsePart` | The existing model response part, which must be a TextPart. | *required* | Returns: | Type | Description | | --- | --- | | `TextPart` | A new TextPart with updated text content. 
| Raises: | Type | Description | | --- | --- | | `ValueError` | If part is not a TextPart. | Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def apply(self, part: ModelResponsePart) -> TextPart: """Apply this text delta to an existing `TextPart`. Args: part: The existing model response part, which must be a `TextPart`. Returns: A new `TextPart` with updated text content. Raises: ValueError: If `part` is not a `TextPart`. """ if not isinstance(part, TextPart): raise ValueError('Cannot apply TextPartDeltas to non-TextParts') # pragma: no cover return replace(part, content=part.content + self.content_delta) ``` ### ThinkingPartDelta A partial update (delta) for a `ThinkingPart` to append new thinking content. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ThinkingPartDelta: """A partial update (delta) for a `ThinkingPart` to append new thinking content.""" content_delta: str | None = None """The incremental thinking content to add to the existing `ThinkingPart` content.""" signature_delta: str | None = None """Optional signature delta. Note this is never treated as a delta — it can replace None. """ part_delta_kind: Literal['thinking'] = 'thinking' """Part delta type identifier, used as a discriminator.""" @overload def apply(self, part: ModelResponsePart) -> ThinkingPart: ... @overload def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | ThinkingPartDelta: ... def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | ThinkingPartDelta: """Apply this thinking delta to an existing `ThinkingPart`. Args: part: The existing model response part, which must be a `ThinkingPart`. Returns: A new `ThinkingPart` with updated thinking content. Raises: ValueError: If `part` is not a `ThinkingPart`. """ if isinstance(part, ThinkingPart): new_content = part.content + self.content_delta if self.content_delta else part.content new_signature = self.signature_delta if self.signature_delta is not None else part.signature return replace(part, content=new_content, signature=new_signature) elif isinstance(part, ThinkingPartDelta): if self.content_delta is None and self.signature_delta is None: raise ValueError('Cannot apply ThinkingPartDelta with no content or signature') if self.signature_delta is not None: return replace(part, signature_delta=self.signature_delta) if self.content_delta is not None: return replace(part, content_delta=self.content_delta) raise ValueError( # pragma: no cover f'Cannot apply ThinkingPartDeltas to non-ThinkingParts or non-ThinkingPartDeltas ({part=}, {self=})' ) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### content_delta ```python content_delta: str | None = None ``` The incremental thinking content to add to the existing `ThinkingPart` content. #### signature_delta ```python signature_delta: str | None = None ``` Optional signature delta. Note this is never treated as a delta — it can replace None. #### part_delta_kind ```python part_delta_kind: Literal['thinking'] = 'thinking' ``` Part delta type identifier, used as a discriminator. #### apply ```python apply(part: ModelResponsePart) -> ThinkingPart ``` ```python apply( part: ModelResponsePart | ThinkingPartDelta, ) -> ThinkingPart | ThinkingPartDelta ``` ```python apply( part: ModelResponsePart | ThinkingPartDelta, ) -> ThinkingPart | ThinkingPartDelta ``` Apply this thinking delta to an existing `ThinkingPart`. 
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `part` | `ModelResponsePart | ThinkingPartDelta` | The existing model response part, which must be a ThinkingPart. | *required* | Returns: | Type | Description | | --- | --- | | `ThinkingPart | ThinkingPartDelta` | A new ThinkingPart with updated thinking content. | Raises: | Type | Description | | --- | --- | | `ValueError` | If part is not a ThinkingPart. | Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | ThinkingPartDelta: """Apply this thinking delta to an existing `ThinkingPart`. Args: part: The existing model response part, which must be a `ThinkingPart`. Returns: A new `ThinkingPart` with updated thinking content. Raises: ValueError: If `part` is not a `ThinkingPart`. """ if isinstance(part, ThinkingPart): new_content = part.content + self.content_delta if self.content_delta else part.content new_signature = self.signature_delta if self.signature_delta is not None else part.signature return replace(part, content=new_content, signature=new_signature) elif isinstance(part, ThinkingPartDelta): if self.content_delta is None and self.signature_delta is None: raise ValueError('Cannot apply ThinkingPartDelta with no content or signature') if self.signature_delta is not None: return replace(part, signature_delta=self.signature_delta) if self.content_delta is not None: return replace(part, content_delta=self.content_delta) raise ValueError( # pragma: no cover f'Cannot apply ThinkingPartDeltas to non-ThinkingParts or non-ThinkingPartDeltas ({part=}, {self=})' ) ``` ### ToolCallPartDelta A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class ToolCallPartDelta: """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID.""" tool_name_delta: str | None = None """Incremental text to add to the existing tool name, if any.""" args_delta: str | dict[str, Any] | None = None """Incremental data to add to the tool arguments. If this is a string, it will be appended to existing JSON arguments. If this is a dict, it will be merged with existing dict arguments. """ tool_call_id: str | None = None """Optional tool call identifier, this is used by some models including OpenAI. Note this is never treated as a delta — it can replace None, but otherwise if a non-matching value is provided an error will be raised.""" part_delta_kind: Literal['tool_call'] = 'tool_call' """Part delta type identifier, used as a discriminator.""" def as_part(self) -> ToolCallPart | None: """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`. Returns: A `ToolCallPart` if `tool_name_delta` is set, otherwise `None`. """ if self.tool_name_delta is None: return None return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id()) @overload def apply(self, part: ModelResponsePart) -> ToolCallPart: ... @overload def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ... def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: """Apply this delta to a part or delta, returning a new part or delta with the changes applied. Args: part: The existing model response part or delta to update. 
Returns: Either a new `ToolCallPart` or an updated `ToolCallPartDelta`. Raises: ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`. UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa. """ if isinstance(part, ToolCallPart): return self._apply_to_part(part) if isinstance(part, ToolCallPartDelta): return self._apply_to_delta(part) raise ValueError( # pragma: no cover f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}' ) def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: """Internal helper to apply this delta to another delta.""" if self.tool_name_delta: # Append incremental text to the existing tool_name_delta updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta delta = replace(delta, tool_name_delta=updated_tool_name_delta) if isinstance(self.args_delta, str): if isinstance(delta.args_delta, dict): raise UnexpectedModelBehavior( f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})' ) updated_args_delta = (delta.args_delta or '') + self.args_delta delta = replace(delta, args_delta=updated_args_delta) elif isinstance(self.args_delta, dict): if isinstance(delta.args_delta, str): raise UnexpectedModelBehavior( f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})' ) updated_args_delta = {**(delta.args_delta or {}), **self.args_delta} delta = replace(delta, args_delta=updated_args_delta) if self.tool_call_id: delta = replace(delta, tool_call_id=self.tool_call_id) # If we now have enough data to create a full ToolCallPart, do so if delta.tool_name_delta is not None: return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id()) return delta def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart: """Internal helper to apply this delta directly to a `ToolCallPart`.""" if self.tool_name_delta: # Append incremental text to the existing tool_name tool_name = part.tool_name + self.tool_name_delta part = replace(part, tool_name=tool_name) if isinstance(self.args_delta, str): if isinstance(part.args, dict): raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})') updated_json = (part.args or '') + self.args_delta part = replace(part, args=updated_json) elif isinstance(self.args_delta, dict): if isinstance(part.args, str): raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})') updated_dict = {**(part.args or {}), **self.args_delta} part = replace(part, args=updated_dict) if self.tool_call_id: part = replace(part, tool_call_id=self.tool_call_id) return part __repr__ = _utils.dataclasses_no_defaults_repr ``` #### tool_name_delta ```python tool_name_delta: str | None = None ``` Incremental text to add to the existing tool name, if any. #### args_delta ```python args_delta: str | dict[str, Any] | None = None ``` Incremental data to add to the tool arguments. If this is a string, it will be appended to existing JSON arguments. If this is a dict, it will be merged with existing dict arguments. #### tool_call_id ```python tool_call_id: str | None = None ``` Optional tool call identifier, this is used by some models including OpenAI. Note this is never treated as a delta — it can replace None, but otherwise if a non-matching value is provided an error will be raised. 
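For illustration, a minimal sketch of accumulating streamed JSON arguments with the `apply` method documented below (tool name and arguments invented):

```python
from pydantic_ai.messages import ToolCallPart, ToolCallPartDelta

part = ToolCallPart(tool_name='get_weather', args='{"city": "Par')  # hypothetical partial call
delta = ToolCallPartDelta(args_delta='is"}')
print(delta.apply(part).args)
#> {"city": "Paris"}
```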
#### part_delta_kind ```python part_delta_kind: Literal['tool_call'] = 'tool_call' ``` Part delta type identifier, used as a discriminator. #### as_part ```python as_part() -> ToolCallPart | None ``` Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`. Returns: | Type | Description | | --- | --- | | `ToolCallPart | None` | A ToolCallPart if tool_name_delta is set, otherwise None. | Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def as_part(self) -> ToolCallPart | None: """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`. Returns: A `ToolCallPart` if `tool_name_delta` is set, otherwise `None`. """ if self.tool_name_delta is None: return None return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id()) ``` #### apply ```python apply(part: ModelResponsePart) -> ToolCallPart ``` ```python apply( part: ModelResponsePart | ToolCallPartDelta, ) -> ToolCallPart | ToolCallPartDelta ``` ```python apply( part: ModelResponsePart | ToolCallPartDelta, ) -> ToolCallPart | ToolCallPartDelta ``` Apply this delta to a part or delta, returning a new part or delta with the changes applied. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `part` | `ModelResponsePart | ToolCallPartDelta` | The existing model response part or delta to update. | *required* | Returns: | Type | Description | | --- | --- | | `ToolCallPart | ToolCallPartDelta` | Either a new ToolCallPart or an updated ToolCallPartDelta. | Raises: | Type | Description | | --- | --- | | `ValueError` | If part is neither a ToolCallPart nor a ToolCallPartDelta. | | `UnexpectedModelBehavior` | If applying JSON deltas to dict arguments or vice versa. | Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: """Apply this delta to a part or delta, returning a new part or delta with the changes applied. Args: part: The existing model response part or delta to update. Returns: Either a new `ToolCallPart` or an updated `ToolCallPartDelta`. Raises: ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`. UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa. """ if isinstance(part, ToolCallPart): return self._apply_to_part(part) if isinstance(part, ToolCallPartDelta): return self._apply_to_delta(part) raise ValueError( # pragma: no cover f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}' ) ``` ### ModelResponsePartDelta ```python ModelResponsePartDelta = Annotated[ Union[ TextPartDelta, ThinkingPartDelta, ToolCallPartDelta ], Discriminator("part_delta_kind"), ] ``` A partial update (delta) for any model response part. ### PartStartEvent An event indicating that a new part has started. If multiple `PartStartEvent`s are received with the same index, the new one should fully replace the old one. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class PartStartEvent: """An event indicating that a new part has started. If multiple `PartStartEvent`s are received with the same index, the new one should fully replace the old one. 
""" index: int """The index of the part within the overall response parts list.""" part: ModelResponsePart """The newly started `ModelResponsePart`.""" event_kind: Literal['part_start'] = 'part_start' """Event type identifier, used as a discriminator.""" __repr__ = _utils.dataclasses_no_defaults_repr ``` #### index ```python index: int ``` The index of the part within the overall response parts list. #### part ```python part: ModelResponsePart ``` The newly started `ModelResponsePart`. #### event_kind ```python event_kind: Literal['part_start'] = 'part_start' ``` Event type identifier, used as a discriminator. ### PartDeltaEvent An event indicating a delta update for an existing part. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class PartDeltaEvent: """An event indicating a delta update for an existing part.""" index: int """The index of the part within the overall response parts list.""" delta: ModelResponsePartDelta """The delta to apply to the specified part.""" event_kind: Literal['part_delta'] = 'part_delta' """Event type identifier, used as a discriminator.""" __repr__ = _utils.dataclasses_no_defaults_repr ``` #### index ```python index: int ``` The index of the part within the overall response parts list. #### delta ```python delta: ModelResponsePartDelta ``` The delta to apply to the specified part. #### event_kind ```python event_kind: Literal['part_delta'] = 'part_delta' ``` Event type identifier, used as a discriminator. ### FinalResultEvent An event indicating the response to the current model request matches the output schema and will produce a result. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class FinalResultEvent: """An event indicating the response to the current model request matches the output schema and will produce a result.""" tool_name: str | None """The name of the output tool that was called. `None` if the result is from text content and not from a tool.""" tool_call_id: str | None """The tool call ID, if any, that this result is associated with.""" event_kind: Literal['final_result'] = 'final_result' """Event type identifier, used as a discriminator.""" __repr__ = _utils.dataclasses_no_defaults_repr ``` #### tool_name ```python tool_name: str | None ``` The name of the output tool that was called. `None` if the result is from text content and not from a tool. #### tool_call_id ```python tool_call_id: str | None ``` The tool call ID, if any, that this result is associated with. #### event_kind ```python event_kind: Literal['final_result'] = 'final_result' ``` Event type identifier, used as a discriminator. ### ModelResponseStreamEvent ```python ModelResponseStreamEvent = Annotated[ Union[PartStartEvent, PartDeltaEvent], Discriminator("event_kind"), ] ``` An event in the model response stream, either starting a new part or applying a delta to an existing one. ### AgentStreamEvent ```python AgentStreamEvent = Annotated[ Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], Discriminator("event_kind"), ] ``` An event in the agent stream. ### FunctionToolCallEvent An event indicating the start to a call to a function tool. 
Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class FunctionToolCallEvent: """An event indicating the start to a call to a function tool.""" part: ToolCallPart """The (function) tool call to make.""" event_kind: Literal['function_tool_call'] = 'function_tool_call' """Event type identifier, used as a discriminator.""" @property def tool_call_id(self) -> str: """An ID used for matching details about the call to its result.""" return self.part.tool_call_id @property @deprecated('`call_id` is deprecated, use `tool_call_id` instead.') def call_id(self) -> str: """An ID used for matching details about the call to its result.""" return self.part.tool_call_id # pragma: no cover __repr__ = _utils.dataclasses_no_defaults_repr ``` #### part ```python part: ToolCallPart ``` The (function) tool call to make. #### event_kind ```python event_kind: Literal["function_tool_call"] = ( "function_tool_call" ) ``` Event type identifier, used as a discriminator. #### tool_call_id ```python tool_call_id: str ``` An ID used for matching details about the call to its result. #### call_id ```python call_id: str ``` An ID used for matching details about the call to its result. ### FunctionToolResultEvent An event indicating the result of a function tool call. Source code in `pydantic_ai_slim/pydantic_ai/messages.py` ```python @dataclass(repr=False) class FunctionToolResultEvent: """An event indicating the result of a function tool call.""" result: ToolReturnPart | RetryPromptPart """The result of the call to the function tool.""" event_kind: Literal['function_tool_result'] = 'function_tool_result' """Event type identifier, used as a discriminator.""" @property def tool_call_id(self) -> str: """An ID used to match the result to its original call.""" return self.result.tool_call_id __repr__ = _utils.dataclasses_no_defaults_repr ``` #### result ```python result: ToolReturnPart | RetryPromptPart ``` The result of the call to the function tool. #### event_kind ```python event_kind: Literal["function_tool_result"] = ( "function_tool_result" ) ``` Event type identifier, used as a discriminator. #### tool_call_id ```python tool_call_id: str ``` An ID used to match the result to its original call. # `pydantic_ai.output` ### OutputDataT ```python OutputDataT = TypeVar( "OutputDataT", default=str, covariant=True ) ``` Covariant type variable for the output data type of a run. ### ToolOutput Bases: `Generic[OutputDataT]` Marker class to use a tool for output and optionally customize the tool. Example: tool_output.py ```python from pydantic import BaseModel from pydantic_ai import Agent, ToolOutput class Fruit(BaseModel): name: str color: str class Vehicle(BaseModel): name: str wheels: int agent = Agent( 'openai:gpt-4o', output_type=[ ToolOutput(Fruit, name='return_fruit'), ToolOutput(Vehicle, name='return_vehicle'), ], ) result = agent.run_sync('What is a banana?') print(repr(result.output)) #> Fruit(name='banana', color='yellow') ``` Source code in `pydantic_ai_slim/pydantic_ai/output.py` ````python @dataclass(init=False) class ToolOutput(Generic[OutputDataT]): """Marker class to use a tool for output and optionally customize the tool. 
Example: ```python {title="tool_output.py"} from pydantic import BaseModel from pydantic_ai import Agent, ToolOutput class Fruit(BaseModel): name: str color: str class Vehicle(BaseModel): name: str wheels: int agent = Agent( 'openai:gpt-4o', output_type=[ ToolOutput(Fruit, name='return_fruit'), ToolOutput(Vehicle, name='return_vehicle'), ], ) result = agent.run_sync('What is a banana?') print(repr(result.output)) #> Fruit(name='banana', color='yellow') ``` """ output: OutputTypeOrFunction[OutputDataT] """An output type or function.""" name: str | None """The name of the tool that will be passed to the model. If not specified and only one output is provided, `final_result` will be used. If multiple outputs are provided, the name of the output type or function will be added to the tool name.""" description: str | None """The description of the tool that will be passed to the model. If not specified, the docstring of the output type or function will be used.""" max_retries: int | None """The maximum number of retries for the tool.""" strict: bool | None """Whether to use strict mode for the tool.""" def __init__( self, type_: OutputTypeOrFunction[OutputDataT], *, name: str | None = None, description: str | None = None, max_retries: int | None = None, strict: bool | None = None, ): self.output = type_ self.name = name self.description = description self.max_retries = max_retries self.strict = strict ```` #### output ```python output: OutputTypeOrFunction[OutputDataT] = type_ ``` An output type or function. #### name ```python name: str | None = name ``` The name of the tool that will be passed to the model. If not specified and only one output is provided, `final_result` will be used. If multiple outputs are provided, the name of the output type or function will be added to the tool name. #### description ```python description: str | None = description ``` The description of the tool that will be passed to the model. If not specified, the docstring of the output type or function will be used. #### max_retries ```python max_retries: int | None = max_retries ``` The maximum number of retries for the tool. #### strict ```python strict: bool | None = strict ``` Whether to use strict mode for the tool. ### NativeOutput Bases: `Generic[OutputDataT]` Marker class to use the model's native structured outputs functionality for outputs and optionally customize the name and description. Example: native_output.py ```python from tool_output import Fruit, Vehicle from pydantic_ai import Agent, NativeOutput agent = Agent( 'openai:gpt-4o', output_type=NativeOutput( [Fruit, Vehicle], name='Fruit or vehicle', description='Return a fruit or vehicle.' ), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) #> Vehicle(name='Ford Explorer', wheels=4) ``` Source code in `pydantic_ai_slim/pydantic_ai/output.py` ````python @dataclass(init=False) class NativeOutput(Generic[OutputDataT]): """Marker class to use the model's native structured outputs functionality for outputs and optionally customize the name and description. Example: ```python {title="native_output.py" requires="tool_output.py"} from tool_output import Fruit, Vehicle from pydantic_ai import Agent, NativeOutput agent = Agent( 'openai:gpt-4o', output_type=NativeOutput( [Fruit, Vehicle], name='Fruit or vehicle', description='Return a fruit or vehicle.' 
), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) #> Vehicle(name='Ford Explorer', wheels=4) ``` """ outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]] """The output types or functions.""" name: str | None """The name of the structured output that will be passed to the model. If not specified and only one output is provided, the name of the output type or function will be used.""" description: str | None """The description of the structured output that will be passed to the model. If not specified and only one output is provided, the docstring of the output type or function will be used.""" strict: bool | None """Whether to use strict mode for the output, if the model supports it.""" def __init__( self, outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], *, name: str | None = None, description: str | None = None, strict: bool | None = None, ): self.outputs = outputs self.name = name self.description = description self.strict = strict ```` #### outputs ```python outputs: ( OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]] ) = outputs ``` The output types or functions. #### name ```python name: str | None = name ``` The name of the structured output that will be passed to the model. If not specified and only one output is provided, the name of the output type or function will be used. #### description ```python description: str | None = description ``` The description of the structured output that will be passed to the model. If not specified and only one output is provided, the docstring of the output type or function will be used. #### strict ```python strict: bool | None = strict ``` Whether to use strict mode for the output, if the model supports it. ### PromptedOutput Bases: `Generic[OutputDataT]` Marker class to use a prompt to tell the model what to output and optionally customize the prompt. Example: prompted_output.py ```python from pydantic import BaseModel from tool_output import Vehicle from pydantic_ai import Agent, PromptedOutput class Device(BaseModel): name: str kind: str agent = Agent( 'openai:gpt-4o', output_type=PromptedOutput( [Vehicle, Device], name='Vehicle or device', description='Return a vehicle or device.' ), ) result = agent.run_sync('What is a MacBook?') print(repr(result.output)) #> Device(name='MacBook', kind='laptop') agent = Agent( 'openai:gpt-4o', output_type=PromptedOutput( [Vehicle, Device], template='Gimme some JSON: {schema}' ), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) #> Vehicle(name='Ford Explorer', wheels=4) ``` Source code in `pydantic_ai_slim/pydantic_ai/output.py` ````python @dataclass(init=False) class PromptedOutput(Generic[OutputDataT]): """Marker class to use a prompt to tell the model what to output and optionally customize the prompt. Example: ```python {title="prompted_output.py" requires="tool_output.py"} from pydantic import BaseModel from tool_output import Vehicle from pydantic_ai import Agent, PromptedOutput class Device(BaseModel): name: str kind: str agent = Agent( 'openai:gpt-4o', output_type=PromptedOutput( [Vehicle, Device], name='Vehicle or device', description='Return a vehicle or device.' 
), ) result = agent.run_sync('What is a MacBook?') print(repr(result.output)) #> Device(name='MacBook', kind='laptop') agent = Agent( 'openai:gpt-4o', output_type=PromptedOutput( [Vehicle, Device], template='Gimme some JSON: {schema}' ), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) #> Vehicle(name='Ford Explorer', wheels=4) ``` """ outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]] """The output types or functions.""" name: str | None """The name of the structured output that will be passed to the model. If not specified and only one output is provided, the name of the output type or function will be used.""" description: str | None """The description that will be passed to the model. If not specified and only one output is provided, the docstring of the output type or function will be used.""" template: str | None """Template for the prompt passed to the model. The '{schema}' placeholder will be replaced with the output JSON schema. If not specified, the default template specified on the model's profile will be used. """ def __init__( self, outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], *, name: str | None = None, description: str | None = None, template: str | None = None, ): self.outputs = outputs self.name = name self.description = description self.template = template ```` #### outputs ```python outputs: ( OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]] ) = outputs ``` The output types or functions. #### name ```python name: str | None = name ``` The name of the structured output that will be passed to the model. If not specified and only one output is provided, the name of the output type or function will be used. #### description ```python description: str | None = description ``` The description that will be passed to the model. If not specified and only one output is provided, the docstring of the output type or function will be used. #### template ```python template: str | None = template ``` Template for the prompt passed to the model. The '{schema}' placeholder will be replaced with the output JSON schema. If not specified, the default template specified on the model's profile will be used. ### TextOutput Bases: `Generic[OutputDataT]` Marker class to use text output for an output function taking a string argument. Example: ```python from pydantic_ai import Agent, TextOutput def split_into_words(text: str) -> list[str]: return text.split() agent = Agent( 'openai:gpt-4o', output_type=TextOutput(split_into_words), ) result = agent.run_sync('Who was Albert Einstein?') print(result.output) #> ['Albert', 'Einstein', 'was', 'a', 'German-born', 'theoretical', 'physicist.'] ``` Source code in `pydantic_ai_slim/pydantic_ai/output.py` ````python @dataclass class TextOutput(Generic[OutputDataT]): """Marker class to use text output for an output function taking a string argument. Example: ```python from pydantic_ai import Agent, TextOutput def split_into_words(text: str) -> list[str]: return text.split() agent = Agent( 'openai:gpt-4o', output_type=TextOutput(split_into_words), ) result = agent.run_sync('Who was Albert Einstein?') print(result.output) #> ['Albert', 'Einstein', 'was', 'a', 'German-born', 'theoretical', 'physicist.'] ``` """ output_function: TextOutputFunc[OutputDataT] """The function that will be called to process the model's plain text output. 
The function must take a single string argument.""" ```` #### output_function ```python output_function: TextOutputFunc[OutputDataT] ``` The function that will be called to process the model's plain text output. The function must take a single string argument. # `pydantic_ai.profiles` Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used. Source code in `pydantic_ai_slim/pydantic_ai/profiles/__init__.py` ```python @dataclass class ModelProfile: """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used.""" supports_tools: bool = True """Whether the model supports tools.""" supports_json_schema_output: bool = False """Whether the model supports JSON schema output.""" supports_json_object_output: bool = False """Whether the model supports JSON object output.""" default_structured_output_mode: StructuredOutputMode = 'tool' """The default structured output mode to use for the model.""" prompted_output_template: str = dedent( """ Always respond with a JSON object that's compatible with this schema: {schema} Don't include any text or Markdown fencing before or after. """ ) """The instructions template to use for prompted structured output. The '{schema}' placeholder will be replaced with the JSON schema for the output.""" json_schema_transformer: type[JsonSchemaTransformer] | None = None """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: """Build a ModelProfile subclass instance from a ModelProfile instance.""" if isinstance(profile, cls): return profile return cls().update(profile) def update(self, profile: ModelProfile | None) -> Self: """Update this ModelProfile (subclass) instance with the non-default values from another ModelProfile instance.""" if not profile: return self field_names = set(f.name for f in fields(self)) non_default_attrs = { f.name: getattr(profile, f.name) for f in fields(profile) if f.name in field_names and getattr(profile, f.name) != f.default } return replace(self, **non_default_attrs) ``` ### supports_tools ```python supports_tools: bool = True ``` Whether the model supports tools. ### supports_json_schema_output ```python supports_json_schema_output: bool = False ``` Whether the model supports JSON schema output. ### supports_json_object_output ```python supports_json_object_output: bool = False ``` Whether the model supports JSON object output. ### default_structured_output_mode ```python default_structured_output_mode: StructuredOutputMode = ( "tool" ) ``` The default structured output mode to use for the model. ### prompted_output_template ```python prompted_output_template: str = dedent( "\n Always respond with a JSON object that's compatible with this schema:\n\n {schema}\n\n Don't include any text or Markdown fencing before or after.\n " ) ``` The instructions template to use for prompted structured output. The '{schema}' placeholder will be replaced with the JSON schema for the output. ### json_schema_transformer ```python json_schema_transformer: ( type[JsonSchemaTransformer] | None ) = None ``` The transformer to use to make JSON schemas for tools and structured output compatible with the model. 
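To make the fields above concrete, here is a minimal sketch of building and merging profiles, assuming `ModelProfile` is importable from `pydantic_ai.profiles` (as the source path above suggests) and that `'prompted'` is a valid `StructuredOutputMode` value; the specific values chosen are arbitrary and only for illustration. The `update()` method used here is documented below.

```python
from pydantic_ai.profiles import ModelProfile

# A profile for a hypothetical model family that supports JSON object output
# and should default to prompted structured output instead of tool calls.
# ('prompted' is assumed to be a valid StructuredOutputMode here.)
custom = ModelProfile(
    supports_json_object_output=True,
    default_structured_output_mode='prompted',
)

# `update()` (documented below) copies only the non-default values from the
# other profile, so the `supports_tools=False` set here is preserved.
merged = ModelProfile(supports_tools=False).update(custom)
print(merged.supports_tools, merged.default_structured_output_mode)
#> False prompted
```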
### from_profile ```python from_profile(profile: ModelProfile | None) -> Self ``` Build a ModelProfile subclass instance from a ModelProfile instance. Source code in `pydantic_ai_slim/pydantic_ai/profiles/__init__.py` ```python @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: """Build a ModelProfile subclass instance from a ModelProfile instance.""" if isinstance(profile, cls): return profile return cls().update(profile) ``` ### update ```python update(profile: ModelProfile | None) -> Self ``` Update this ModelProfile (subclass) instance with the non-default values from another ModelProfile instance. Source code in `pydantic_ai_slim/pydantic_ai/profiles/__init__.py` ```python def update(self, profile: ModelProfile | None) -> Self: """Update this ModelProfile (subclass) instance with the non-default values from another ModelProfile instance.""" if not profile: return self field_names = set(f.name for f in fields(self)) non_default_attrs = { f.name: getattr(profile, f.name) for f in fields(profile) if f.name in field_names and getattr(profile, f.name) != f.default } return replace(self, **non_default_attrs) ``` ### OpenAIModelProfile Bases: `ModelProfile` Profile for models used with OpenAIModel. ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. Source code in `pydantic_ai_slim/pydantic_ai/profiles/openai.py` ```python @dataclass class OpenAIModelProfile(ModelProfile): """Profile for models used with OpenAIModel. ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. """ openai_supports_strict_tool_definition: bool = True """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions.""" openai_supports_sampling_settings: bool = True """Turn off to avoid sending sampling settings like `temperature` and `top_p` to models that don't support them, such as OpenAI's o-series reasoning models.""" ``` #### openai_supports_strict_tool_definition ```python openai_supports_strict_tool_definition: bool = True ``` This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions. #### openai_supports_sampling_settings ```python openai_supports_sampling_settings: bool = True ``` Turn off to avoid sending sampling settings like `temperature` and `top_p` to models that don't support them, such as OpenAI's o-series reasoning models. ### openai_model_profile ```python openai_model_profile(model_name: str) -> ModelProfile ``` Get the model profile for an OpenAI model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/openai.py` ```python def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" is_reasoning_model = model_name.startswith('o') # Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later. # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used # when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable. return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, supports_json_schema_output=True, supports_json_object_output=True, openai_supports_sampling_settings=not is_reasoning_model, ) ``` ### OpenAIJsonSchemaTransformer Bases: `JsonSchemaTransformer` Recursively handle the schema to make it compatible with OpenAI strict mode. 
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details, but this basically just requires: * `additionalProperties` must be set to false for each object in the parameters * all fields in properties must be marked as required Source code in `pydantic_ai_slim/pydantic_ai/profiles/openai.py` ```python @dataclass class OpenAIJsonSchemaTransformer(JsonSchemaTransformer): """Recursively handle the schema to make it compatible with OpenAI strict mode. See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details, but this basically just requires: * `additionalProperties` must be set to false for each object in the parameters * all fields in properties must be marked as required """ def __init__(self, schema: JsonSchema, *, strict: bool | None = None): super().__init__(schema, strict=strict) self.root_ref = schema.get('$ref') def walk(self) -> JsonSchema: # Note: OpenAI does not support anyOf at the root in strict mode # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema # that the root schema either has type 'object' or is recursive. result = super().walk() # For recursive models, we need to tweak the schema to make it compatible with strict mode. # Because the following should never change the semantics of the schema we apply it unconditionally. if self.root_ref is not None: result.pop('$ref', None) # We replace references to the self.root_ref with just '#' in the transform method root_key = re.sub(r'^#/\$defs/', '', self.root_ref) result.update(self.defs.get(root_key) or {}) return result def transform(self, schema: JsonSchema) -> JsonSchema: # noqa C901 # Remove unnecessary keys schema.pop('title', None) schema.pop('default', None) schema.pop('$schema', None) schema.pop('discriminator', None) if schema_ref := schema.get('$ref'): if schema_ref == self.root_ref: schema['$ref'] = '#' if len(schema) > 1: # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf". 
# So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf": schema['anyOf'] = [{'$ref': schema.pop('$ref')}] # Track strict-incompatible keys incompatible_values: dict[str, Any] = {} for key in _STRICT_INCOMPATIBLE_KEYS: value = schema.get(key, _sentinel) if value is not _sentinel: incompatible_values[key] = value description = schema.get('description') if incompatible_values: if self.strict is True: notes: list[str] = [] for key, value in incompatible_values.items(): schema.pop(key) notes.append(f'{key}={value}') notes_string = ', '.join(notes) schema['description'] = notes_string if not description else f'{description} ({notes_string})' elif self.strict is None: # pragma: no branch self.is_strict_compatible = False schema_type = schema.get('type') if 'oneOf' in schema: # OpenAI does not support oneOf in strict mode if self.strict is True: schema['anyOf'] = schema.pop('oneOf') else: self.is_strict_compatible = False if schema_type == 'object': if self.strict is True: # additional properties are disallowed schema['additionalProperties'] = False # all properties are required if 'properties' not in schema: schema['properties'] = dict[str, Any]() schema['required'] = list(schema['properties'].keys()) elif self.strict is None: if ( schema.get('additionalProperties') is not False or 'properties' not in schema or 'required' not in schema ): self.is_strict_compatible = False else: required = schema['required'] for k in schema['properties'].keys(): if k not in required: self.is_strict_compatible = False return schema ``` ### anthropic_model_profile ```python anthropic_model_profile( model_name: str, ) -> ModelProfile | None ``` Get the model profile for an Anthropic model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/anthropic.py` ```python def anthropic_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for an Anthropic model.""" return None ``` ### google_model_profile ```python google_model_profile( model_name: str, ) -> ModelProfile | None ``` Get the model profile for a Google model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/google.py` ```python def google_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Google model.""" return ModelProfile( json_schema_transformer=GoogleJsonSchemaTransformer, supports_json_schema_output=True, supports_json_object_output=True, ) ``` ### GoogleJsonSchemaTransformer Bases: `JsonSchemaTransformer` Transforms the JSON Schema from Pydantic to be suitable for Gemini. Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations) a subset of OpenAPI v3.0.3. Specifically: * gemini doesn't allow the `title` keyword to be set * gemini doesn't allow `$defs` — we need to inline the definitions where possible Source code in `pydantic_ai_slim/pydantic_ai/profiles/google.py` ```python class GoogleJsonSchemaTransformer(JsonSchemaTransformer): """Transforms the JSON Schema from Pydantic to be suitable for Gemini. Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations) a subset of OpenAPI v3.0.3. 
Specifically: * gemini doesn't allow the `title` keyword to be set * gemini doesn't allow `$defs` — we need to inline the definitions where possible """ def __init__(self, schema: JsonSchema, *, strict: bool | None = None): super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=True) def transform(self, schema: JsonSchema) -> JsonSchema: # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini additional_properties = schema.pop( 'additionalProperties', None ) # don't pop yet so it's included in the warning if additional_properties: original_schema = {**schema, 'additionalProperties': additional_properties} warnings.warn( '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.' f' Full schema: {self.schema}\n\n' f'Source of additionalProperties within the full schema: {original_schema}\n\n' 'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n' "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub" ' and we will fix this behavior.', UserWarning, ) schema.pop('title', None) schema.pop('default', None) schema.pop('$schema', None) if (const := schema.pop('const', None)) is not None: # Gemini doesn't support const, but it does support enum with a single value schema['enum'] = [const] schema.pop('discriminator', None) schema.pop('examples', None) # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema # where we add notes about these properties to the field description? schema.pop('exclusiveMaximum', None) schema.pop('exclusiveMinimum', None) # Gemini only supports string enums, so we need to convert any enum values to strings. # Pydantic will take care of transforming the transformed string values to the correct type. if enum := schema.get('enum'): schema['type'] = 'string' schema['enum'] = [str(val) for val in enum] type_ = schema.get('type') if 'oneOf' in schema and 'type' not in schema: # pragma: no cover # This gets hit when we have a discriminated union # Gemini returns an API error in this case even though it says in its error message it shouldn't... # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent schema['anyOf'] = schema.pop('oneOf') if type_ == 'string' and (fmt := schema.pop('format', None)): description = schema.get('description') if description: schema['description'] = f'{description} (format: {fmt})' else: schema['description'] = f'Format: {fmt}' if '$ref' in schema: raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}') if 'prefixItems' in schema: # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility prefix_items = schema.pop('prefixItems') items = schema.get('items') unique_items = [items] if items is not None else [] for item in prefix_items: if item not in unique_items: unique_items.append(item) if len(unique_items) > 1: # pragma: no cover schema['items'] = {'anyOf': unique_items} elif len(unique_items) == 1: # pragma: no branch schema['items'] = unique_items[0] schema.setdefault('minItems', len(prefix_items)) if items is None: # pragma: no branch schema.setdefault('maxItems', len(prefix_items)) return schema ``` ### meta_model_profile ```python meta_model_profile(model_name: str) -> ModelProfile | None ``` Get the model profile for a Meta model. 
Source code in `pydantic_ai_slim/pydantic_ai/profiles/meta.py` ```python def meta_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Meta model.""" return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer) ``` ### amazon_model_profile ```python amazon_model_profile( model_name: str, ) -> ModelProfile | None ``` Get the model profile for an Amazon model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/amazon.py` ```python def amazon_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for an Amazon model.""" return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer) ``` ### deepseek_model_profile ```python deepseek_model_profile( model_name: str, ) -> ModelProfile | None ``` Get the model profile for a DeepSeek model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/deepseek.py` ```python def deepseek_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a DeepSeek model.""" return None ``` ### grok_model_profile ```python grok_model_profile(model_name: str) -> ModelProfile | None ``` Get the model profile for a Grok model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/grok.py` ```python def grok_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Grok model.""" return None ``` ### mistral_model_profile ```python mistral_model_profile( model_name: str, ) -> ModelProfile | None ``` Get the model profile for a Mistral model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/mistral.py` ```python def mistral_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Mistral model.""" return None ``` ### qwen_model_profile ```python qwen_model_profile(model_name: str) -> ModelProfile | None ``` Get the model profile for a Qwen model. Source code in `pydantic_ai_slim/pydantic_ai/profiles/qwen.py` ```python def qwen_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Qwen model.""" return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer) ``` # `pydantic_ai.providers` Bases: `ABC`, `Generic[InterfaceClient]` Abstract class for a provider. The provider is in charge of providing an authenticated client to the API. Each provider only supports a specific interface. An interface can be supported by multiple providers. For example, the OpenAIModel interface can be supported by the OpenAIProvider and the DeepSeekProvider. Source code in `pydantic_ai_slim/pydantic_ai/providers/__init__.py` ```python class Provider(ABC, Generic[InterfaceClient]): """Abstract class for a provider. The provider is in charge of providing an authenticated client to the API. Each provider only supports a specific interface. An interface can be supported by multiple providers. For example, the OpenAIModel interface can be supported by the OpenAIProvider and the DeepSeekProvider. """ _client: InterfaceClient @property @abstractmethod def name(self) -> str: """The provider name.""" raise NotImplementedError() @property @abstractmethod def base_url(self) -> str: """The base URL for the provider API.""" raise NotImplementedError() @property @abstractmethod def client(self) -> InterfaceClient: """The client for the provider.""" raise NotImplementedError() def model_profile(self, model_name: str) -> ModelProfile | None: """The model profile for the named model, if available.""" return None # pragma: no cover ``` ### name ```python name: str ``` The provider name. 
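As an illustration of the `Provider` contract quoted above — the abstract `name` property here, plus the `base_url`, `client`, and optional `model_profile` members documented in the following sections — here is a hedged sketch of a custom provider wrapping a plain `httpx.AsyncClient`. It assumes `Provider` and `ModelProfile` are importable from `pydantic_ai.providers` and `pydantic_ai.profiles` as the source paths suggest; the provider name, endpoint, and authentication scheme are hypothetical and not part of PydanticAI.

```python
import httpx

from pydantic_ai.profiles import ModelProfile
from pydantic_ai.providers import Provider


class ExampleProvider(Provider[httpx.AsyncClient]):
    """Hypothetical provider for an OpenAI-compatible API, for illustration only."""

    def __init__(self, api_key: str, http_client: httpx.AsyncClient | None = None) -> None:
        # The provider's job is to hand back an authenticated client.
        self._client = http_client or httpx.AsyncClient(
            base_url=self.base_url,
            headers={'Authorization': f'Bearer {api_key}'},
        )

    @property
    def name(self) -> str:
        return 'example'

    @property
    def base_url(self) -> str:
        return 'https://api.example.com/v1'  # hypothetical endpoint

    @property
    def client(self) -> httpx.AsyncClient:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        # Optionally describe how requests to this provider's models should be built.
        return ModelProfile(supports_json_object_output=True)
```

The built-in providers documented below follow the same pattern, typically adding environment-variable fallbacks for API keys and reusing a cached HTTP client.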
### base_url ```python base_url: str ``` The base URL for the provider API. ### client ```python client: InterfaceClient ``` The client for the provider. ### model_profile ```python model_profile(model_name: str) -> ModelProfile | None ``` The model profile for the named model, if available. Source code in `pydantic_ai_slim/pydantic_ai/providers/__init__.py` ```python def model_profile(self, model_name: str) -> ModelProfile | None: """The model profile for the named model, if available.""" return None # pragma: no cover ``` ### GoogleProvider Bases: `Provider[Client]` Provider for Google. Source code in `pydantic_ai_slim/pydantic_ai/providers/google.py` ```python class GoogleProvider(Provider[genai.Client]): """Provider for Google.""" @property def name(self) -> str: return 'google-vertex' if self._client._api_client.vertexai else 'google-gla' # type: ignore[reportPrivateUsage] @property def base_url(self) -> str: return str(self._client._api_client._http_options.base_url) # type: ignore[reportPrivateUsage] @property def client(self) -> genai.Client: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return google_model_profile(model_name) @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__( self, *, credentials: Credentials | None = None, project: str | None = None, location: VertexAILocation | Literal['global'] | None = None, ) -> None: ... @overload def __init__(self, *, client: genai.Client) -> None: ... @overload def __init__(self, *, vertexai: bool = False) -> None: ... def __init__( self, *, api_key: str | None = None, credentials: Credentials | None = None, project: str | None = None, location: VertexAILocation | Literal['global'] | None = None, client: genai.Client | None = None, vertexai: bool | None = None, ) -> None: """Create a new Google provider. Args: api_key: The `API key `_ to use for authentication. It can also be set via the `GOOGLE_API_KEY` environment variable. Applies to the Gemini Developer API only. credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be obtained from environment variables and default credentials. For more information, see Set up Application Default Credentials. Applies to the Vertex AI API only. project: The Google Cloud project ID to use for quota. Can be obtained from environment variables (for example, GOOGLE_CLOUD_PROJECT). Applies to the Vertex AI API only. location: The location to send API requests to (for example, us-central1). Can be obtained from environment variables. Applies to the Vertex AI API only. client: A pre-initialized client to use. vertexai: Force the use of the Vertex AI API. If `False`, the Google Generative Language API will be used. Defaults to `False`. """ if client is None: # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility. api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY') if vertexai is None: # pragma: lax no cover vertexai = bool(location or project or credentials) if not vertexai: if api_key is None: raise UserError( # pragma: no cover 'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`' 'to use the Google Generative Language API.' 
) self._client = genai.Client( vertexai=vertexai, api_key=api_key, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: self._client = genai.Client( vertexai=vertexai, project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'), # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149: # Currently `us-central1` supports the most models by far of any region including `global`, but not # all of them. `us-central1` has all google models but is missing some Anthropic partner models, # which use `us-east5` instead. `global` has fewer models but higher availability. # For more details, check: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions location=location or os.environ.get('GOOGLE_CLOUD_LOCATION') or 'us-central1', credentials=credentials, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: self._client = client # pragma: lax no cover ``` #### __init__ ```python __init__(*, api_key: str) -> None ``` ```python __init__( *, credentials: Credentials | None = None, project: str | None = None, location: ( VertexAILocation | Literal["global"] | None ) = None ) -> None ``` ```python __init__(*, client: Client) -> None ``` ```python __init__(*, vertexai: bool = False) -> None ``` ```python __init__( *, api_key: str | None = None, credentials: Credentials | None = None, project: str | None = None, location: ( VertexAILocation | Literal["global"] | None ) = None, client: Client | None = None, vertexai: bool | None = None ) -> None ``` Create a new Google provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `api_key` | `str | None` | The API key \_ to use for authentication. It can also be set via the GOOGLE_API_KEY environment variable. Applies to the Gemini Developer API only. | `None` | | `credentials` | `Credentials | None` | The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be obtained from environment variables and default credentials. For more information, see Set up Application Default Credentials. Applies to the Vertex AI API only. | `None` | | `project` | `str | None` | The Google Cloud project ID to use for quota. Can be obtained from environment variables (for example, GOOGLE_CLOUD_PROJECT). Applies to the Vertex AI API only. | `None` | | `location` | `VertexAILocation | Literal['global'] | None` | The location to send API requests to (for example, us-central1). Can be obtained from environment variables. Applies to the Vertex AI API only. | `None` | | `client` | `Client | None` | A pre-initialized client to use. | `None` | | `vertexai` | `bool | None` | Force the use of the Vertex AI API. If False, the Google Generative Language API will be used. Defaults to False. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/google.py` ```python def __init__( self, *, api_key: str | None = None, credentials: Credentials | None = None, project: str | None = None, location: VertexAILocation | Literal['global'] | None = None, client: genai.Client | None = None, vertexai: bool | None = None, ) -> None: """Create a new Google provider. Args: api_key: The `API key `_ to use for authentication. It can also be set via the `GOOGLE_API_KEY` environment variable. Applies to the Gemini Developer API only. credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be obtained from environment variables and default credentials. For more information, see Set up Application Default Credentials. 
Applies to the Vertex AI API only. project: The Google Cloud project ID to use for quota. Can be obtained from environment variables (for example, GOOGLE_CLOUD_PROJECT). Applies to the Vertex AI API only. location: The location to send API requests to (for example, us-central1). Can be obtained from environment variables. Applies to the Vertex AI API only. client: A pre-initialized client to use. vertexai: Force the use of the Vertex AI API. If `False`, the Google Generative Language API will be used. Defaults to `False`. """ if client is None: # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility. api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY') if vertexai is None: # pragma: lax no cover vertexai = bool(location or project or credentials) if not vertexai: if api_key is None: raise UserError( # pragma: no cover 'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`' 'to use the Google Generative Language API.' ) self._client = genai.Client( vertexai=vertexai, api_key=api_key, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: self._client = genai.Client( vertexai=vertexai, project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'), # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149: # Currently `us-central1` supports the most models by far of any region including `global`, but not # all of them. `us-central1` has all google models but is missing some Anthropic partner models, # which use `us-east5` instead. `global` has fewer models but higher availability. # For more details, check: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions location=location or os.environ.get('GOOGLE_CLOUD_LOCATION') or 'us-central1', credentials=credentials, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: self._client = client # pragma: lax no cover ``` ### VertexAILocation ```python VertexAILocation = Literal[ "asia-east1", "asia-east2", "asia-northeast1", "asia-northeast3", "asia-south1", "asia-southeast1", "australia-southeast1", "europe-central2", "europe-north1", "europe-southwest1", "europe-west1", "europe-west2", "europe-west3", "europe-west4", "europe-west6", "europe-west8", "europe-west9", "me-central1", "me-central2", "me-west1", "northamerica-northeast1", "southamerica-east1", "us-central1", "us-east1", "us-east4", "us-east5", "us-south1", "us-west1", "us-west4", ] ``` Regions available for Vertex AI. More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations). ### GoogleVertexProvider Bases: `Provider[AsyncClient]` Provider for Vertex AI API. 
Source code in `pydantic_ai_slim/pydantic_ai/providers/google_vertex.py` ```python class GoogleVertexProvider(Provider[httpx.AsyncClient]): """Provider for Vertex AI API.""" @property def name(self) -> str: return 'google-vertex' @property def base_url(self) -> str: return ( f'https://{self.region}-aiplatform.googleapis.com/v1' f'/projects/{self.project_id}' f'/locations/{self.region}' f'/publishers/{self.model_publisher}/models/' ) @property def client(self) -> httpx.AsyncClient: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return google_model_profile(model_name) # pragma: lax no cover @overload def __init__( self, *, service_account_file: Path | str | None = None, project_id: str | None = None, region: VertexAiRegion = 'us-central1', model_publisher: str = 'google', http_client: httpx.AsyncClient | None = None, ) -> None: ... @overload def __init__( self, *, service_account_info: Mapping[str, str] | None = None, project_id: str | None = None, region: VertexAiRegion = 'us-central1', model_publisher: str = 'google', http_client: httpx.AsyncClient | None = None, ) -> None: ... def __init__( self, *, service_account_file: Path | str | None = None, service_account_info: Mapping[str, str] | None = None, project_id: str | None = None, region: VertexAiRegion = 'us-central1', model_publisher: str = 'google', http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Vertex AI provider. Args: service_account_file: Path to a service account file. If not provided, the service_account_info or default environment credentials will be used. service_account_info: The loaded service_account_file contents. If not provided, the service_account_file or default environment credentials will be used. project_id: The project ID to use, if not provided it will be taken from the credentials. region: The region to make requests to. model_publisher: The model publisher to use, I couldn't find a good list of available publishers, and from trial and error it seems non-google models don't work with the `generateContent` and `streamGenerateContent` functions, hence only `google` is currently supported. Please create an issue or PR if you know how to use other publishers. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. 
""" if service_account_file and service_account_info: raise ValueError('Only one of `service_account_file` or `service_account_info` can be provided.') self._client = http_client or cached_async_http_client(provider='google-vertex') self.service_account_file = service_account_file self.service_account_info = service_account_info self.project_id = project_id self.region = region self.model_publisher = model_publisher self._client.auth = _VertexAIAuth(service_account_file, service_account_info, project_id, region) self._client.base_url = self.base_url ``` #### __init__ ```python __init__( *, service_account_file: Path | str | None = None, project_id: str | None = None, region: VertexAiRegion = "us-central1", model_publisher: str = "google", http_client: AsyncClient | None = None ) -> None ``` ```python __init__( *, service_account_info: Mapping[str, str] | None = None, project_id: str | None = None, region: VertexAiRegion = "us-central1", model_publisher: str = "google", http_client: AsyncClient | None = None ) -> None ``` ```python __init__( *, service_account_file: Path | str | None = None, service_account_info: Mapping[str, str] | None = None, project_id: str | None = None, region: VertexAiRegion = "us-central1", model_publisher: str = "google", http_client: AsyncClient | None = None ) -> None ``` Create a new Vertex AI provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `service_account_file` | `Path | str | None` | Path to a service account file. If not provided, the service_account_info or default environment credentials will be used. | `None` | | `service_account_info` | `Mapping[str, str] | None` | The loaded service_account_file contents. If not provided, the service_account_file or default environment credentials will be used. | `None` | | `project_id` | `str | None` | The project ID to use, if not provided it will be taken from the credentials. | `None` | | `region` | `VertexAiRegion` | The region to make requests to. | `'us-central1'` | | `model_publisher` | `str` | The model publisher to use, I couldn't find a good list of available publishers, and from trial and error it seems non-google models don't work with the generateContent and streamGenerateContent functions, hence only google is currently supported. Please create an issue or PR if you know how to use other publishers. | `'google'` | | `http_client` | `AsyncClient | None` | An existing httpx.AsyncClient to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/google_vertex.py` ```python def __init__( self, *, service_account_file: Path | str | None = None, service_account_info: Mapping[str, str] | None = None, project_id: str | None = None, region: VertexAiRegion = 'us-central1', model_publisher: str = 'google', http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Vertex AI provider. Args: service_account_file: Path to a service account file. If not provided, the service_account_info or default environment credentials will be used. service_account_info: The loaded service_account_file contents. If not provided, the service_account_file or default environment credentials will be used. project_id: The project ID to use, if not provided it will be taken from the credentials. region: The region to make requests to. 
model_publisher: The model publisher to use, I couldn't find a good list of available publishers, and from trial and error it seems non-google models don't work with the `generateContent` and `streamGenerateContent` functions, hence only `google` is currently supported. Please create an issue or PR if you know how to use other publishers. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ if service_account_file and service_account_info: raise ValueError('Only one of `service_account_file` or `service_account_info` can be provided.') self._client = http_client or cached_async_http_client(provider='google-vertex') self.service_account_file = service_account_file self.service_account_info = service_account_info self.project_id = project_id self.region = region self.model_publisher = model_publisher self._client.auth = _VertexAIAuth(service_account_file, service_account_info, project_id, region) self._client.base_url = self.base_url ``` ### OpenAIProvider Bases: `Provider[AsyncOpenAI]` Provider for OpenAI API. Source code in `pydantic_ai_slim/pydantic_ai/providers/openai.py` ```python class OpenAIProvider(Provider[AsyncOpenAI]): """Provider for OpenAI API.""" @property def name(self) -> str: return 'openai' # pragma: no cover @property def base_url(self) -> str: return str(self.client.base_url) @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return openai_model_profile(model_name) def __init__( self, base_url: str | None = None, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new OpenAI provider. Args: base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable will be used if available. Otherwise, defaults to OpenAI's base url. api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable will be used if available. openai_client: An existing [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage) client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ # This is a workaround for the OpenAI client requiring an API key, whilst locally served, # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required. if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None: api_key = 'api-key-not-set' if openai_client is not None: assert base_url is None, 'Cannot provide both `openai_client` and `base_url`' assert http_client is None, 'Cannot provide both `openai_client` and `http_client`' assert api_key is None, 'Cannot provide both `openai_client` and `api_key`' self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='openai') self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client) ``` #### __init__ ```python __init__( base_url: str | None = None, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncClient | None = None, ) -> None ``` Create a new OpenAI provider. 
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `base_url` | `str | None` | The base url for the OpenAI requests. If not provided, the OPENAI_BASE_URL environment variable will be used if available. Otherwise, defaults to OpenAI's base url. | `None` | | `api_key` | `str | None` | The API key to use for authentication, if not provided, the OPENAI_API_KEY environment variable will be used if available. | `None` | | `openai_client` | `AsyncOpenAI | None` | An existing AsyncOpenAI client to use. If provided, base_url, api_key, and http_client must be None. | `None` | | `http_client` | `AsyncClient | None` | An existing httpx.AsyncClient to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/openai.py` ```python def __init__( self, base_url: str | None = None, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new OpenAI provider. Args: base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable will be used if available. Otherwise, defaults to OpenAI's base url. api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable will be used if available. openai_client: An existing [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage) client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ # This is a workaround for the OpenAI client requiring an API key, whilst locally served, # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required. if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None: api_key = 'api-key-not-set' if openai_client is not None: assert base_url is None, 'Cannot provide both `openai_client` and `base_url`' assert http_client is None, 'Cannot provide both `openai_client` and `http_client`' assert api_key is None, 'Cannot provide both `openai_client` and `api_key`' self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='openai') self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client) ``` ### DeepSeekProvider Bases: `Provider[AsyncOpenAI]` Provider for DeepSeek API. Source code in `pydantic_ai_slim/pydantic_ai/providers/deepseek.py` ```python class DeepSeekProvider(Provider[AsyncOpenAI]): """Provider for DeepSeek API.""" @property def name(self) -> str: return 'deepseek' @property def base_url(self) -> str: return 'https://api.deepseek.com' @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: profile = deepseek_model_profile(model_name) # As DeepSeekProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer, # we need to maintain that behavior unless json_schema_transformer is set explicitly. # This was not the case when using a DeepSeek model with another model class (e.g. BedrockConverseModel or GroqModel), # so we won't do this in `deepseek_model_profile` unless we learn it's always needed. 
return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) @overload def __init__(self) -> None: ... @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: api_key = api_key or os.getenv('DEEPSEEK_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)`' 'to use the DeepSeek provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='deepseek') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` ### BedrockModelProfile Bases: `ModelProfile` Profile for models used with BedrockModel. ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. Source code in `pydantic_ai_slim/pydantic_ai/providers/bedrock.py` ```python @dataclass class BedrockModelProfile(ModelProfile): """Profile for models used with BedrockModel. ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. """ bedrock_supports_tool_choice: bool = True bedrock_tool_result_format: Literal['text', 'json'] = 'text' ``` ### BedrockProvider Bases: `Provider[BaseClient]` Provider for AWS Bedrock. Source code in `pydantic_ai_slim/pydantic_ai/providers/bedrock.py` ```python class BedrockProvider(Provider[BaseClient]): """Provider for AWS Bedrock.""" @property def name(self) -> str: return 'bedrock' @property def base_url(self) -> str: return self._client.meta.endpoint_url @property def client(self) -> BaseClient: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = { 'anthropic': lambda model_name: BedrockModelProfile(bedrock_supports_tool_choice=False).update( anthropic_model_profile(model_name) ), 'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update( mistral_model_profile(model_name) ), 'cohere': cohere_model_profile, 'amazon': amazon_model_profile, 'meta': meta_model_profile, 'deepseek': deepseek_model_profile, } # Split the model name into parts parts = model_name.split('.', 2) # Handle regional prefixes (e.g. "us.") if len(parts) > 2 and len(parts[0]) == 2: parts = parts[1:] if len(parts) < 2: return None provider = parts[0] model_name_with_version = parts[1] # Remove version suffix if it matches the format (e.g. "-v1:0" or "-v14") version_match = re.match(r'(.+)-v\d+(?::\d+)?$', model_name_with_version) if version_match: model_name = version_match.group(1) else: model_name = model_name_with_version if provider in provider_to_profile: return provider_to_profile[provider](model_name) return None @overload def __init__(self, *, bedrock_client: BaseClient) -> None: ... 
@overload def __init__( self, *, region_name: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, profile_name: str | None = None, aws_read_timeout: float | None = None, aws_connect_timeout: float | None = None, ) -> None: ... def __init__( self, *, bedrock_client: BaseClient | None = None, region_name: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, profile_name: str | None = None, aws_read_timeout: float | None = None, aws_connect_timeout: float | None = None, ) -> None: """Initialize the Bedrock provider. Args: bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored. region_name: The AWS region name. aws_access_key_id: The AWS access key ID. aws_secret_access_key: The AWS secret access key. aws_session_token: The AWS session token. profile_name: The AWS profile name. aws_read_timeout: The read timeout for Bedrock client. aws_connect_timeout: The connect timeout for Bedrock client. """ if bedrock_client is not None: self._client = bedrock_client else: try: read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300)) connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60)) session = boto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, region_name=region_name, profile_name=profile_name, ) self._client = session.client( # type: ignore[reportUnknownMemberType] 'bedrock-runtime', config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout), ) except NoRegionError as exc: # pragma: no cover raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc ``` #### __init__ ```python __init__(*, bedrock_client: BaseClient) -> None ``` ```python __init__( *, region_name: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, profile_name: str | None = None, aws_read_timeout: float | None = None, aws_connect_timeout: float | None = None ) -> None ``` ```python __init__( *, bedrock_client: BaseClient | None = None, region_name: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, profile_name: str | None = None, aws_read_timeout: float | None = None, aws_connect_timeout: float | None = None ) -> None ``` Initialize the Bedrock provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `bedrock_client` | `BaseClient | None` | A boto3 client for Bedrock Runtime. If provided, other arguments are ignored. | `None` | | `region_name` | `str | None` | The AWS region name. | `None` | | `aws_access_key_id` | `str | None` | The AWS access key ID. | `None` | | `aws_secret_access_key` | `str | None` | The AWS secret access key. | `None` | | `aws_session_token` | `str | None` | The AWS session token. | `None` | | `profile_name` | `str | None` | The AWS profile name. | `None` | | `aws_read_timeout` | `float | None` | The read timeout for Bedrock client. | `None` | | `aws_connect_timeout` | `float | None` | The connect timeout for Bedrock client. 
| `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/bedrock.py` ```python def __init__( self, *, bedrock_client: BaseClient | None = None, region_name: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_session_token: str | None = None, profile_name: str | None = None, aws_read_timeout: float | None = None, aws_connect_timeout: float | None = None, ) -> None: """Initialize the Bedrock provider. Args: bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored. region_name: The AWS region name. aws_access_key_id: The AWS access key ID. aws_secret_access_key: The AWS secret access key. aws_session_token: The AWS session token. profile_name: The AWS profile name. aws_read_timeout: The read timeout for Bedrock client. aws_connect_timeout: The connect timeout for Bedrock client. """ if bedrock_client is not None: self._client = bedrock_client else: try: read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300)) connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60)) session = boto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, region_name=region_name, profile_name=profile_name, ) self._client = session.client( # type: ignore[reportUnknownMemberType] 'bedrock-runtime', config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout), ) except NoRegionError as exc: # pragma: no cover raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc ``` ### GroqProvider Bases: `Provider[AsyncGroq]` Provider for Groq API. Source code in `pydantic_ai_slim/pydantic_ai/providers/groq.py` ```python class GroqProvider(Provider[AsyncGroq]): """Provider for Groq API.""" @property def name(self) -> str: return 'groq' @property def base_url(self) -> str: return os.environ.get('GROQ_BASE_URL', 'https://api.groq.com') @property def client(self) -> AsyncGroq: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: prefix_to_profile = { 'llama': meta_model_profile, 'meta-llama/': meta_model_profile, 'gemma': google_model_profile, 'qwen': qwen_model_profile, 'deepseek': deepseek_model_profile, 'mistral': mistral_model_profile, } for prefix, profile_func in prefix_to_profile.items(): model_name = model_name.lower() if model_name.startswith(prefix): if prefix.endswith('/'): model_name = model_name[len(prefix) :] return profile_func(model_name) return None @overload def __init__(self, *, groq_client: AsyncGroq | None = None) -> None: ... @overload def __init__(self, *, api_key: str | None = None, http_client: AsyncHTTPClient | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, groq_client: AsyncGroq | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new Groq provider. Args: api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable will be used if available. groq_client: An existing [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage) client to use. If provided, `api_key` and `http_client` must be `None`. http_client: An existing `AsyncHTTPClient` to use for making HTTP requests. 
""" if groq_client is not None: assert http_client is None, 'Cannot provide both `groq_client` and `http_client`' assert api_key is None, 'Cannot provide both `groq_client` and `api_key`' self._client = groq_client else: api_key = api_key or os.environ.get('GROQ_API_KEY') if not api_key: raise UserError( 'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`' 'to use the Groq provider.' ) elif http_client is not None: self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='groq') self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` #### __init__ ```python __init__(*, groq_client: AsyncGroq | None = None) -> None ``` ```python __init__( *, api_key: str | None = None, http_client: AsyncClient | None = None ) -> None ``` ```python __init__( *, api_key: str | None = None, groq_client: AsyncGroq | None = None, http_client: AsyncClient | None = None ) -> None ``` Create a new Groq provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `api_key` | `str | None` | The API key to use for authentication, if not provided, the GROQ_API_KEY environment variable will be used if available. | `None` | | `groq_client` | `AsyncGroq | None` | An existing AsyncGroq client to use. If provided, api_key and http_client must be None. | `None` | | `http_client` | `AsyncClient | None` | An existing AsyncHTTPClient to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/groq.py` ```python def __init__( self, *, api_key: str | None = None, groq_client: AsyncGroq | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new Groq provider. Args: api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable will be used if available. groq_client: An existing [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage) client to use. If provided, `api_key` and `http_client` must be `None`. http_client: An existing `AsyncHTTPClient` to use for making HTTP requests. """ if groq_client is not None: assert http_client is None, 'Cannot provide both `groq_client` and `http_client`' assert api_key is None, 'Cannot provide both `groq_client` and `api_key`' self._client = groq_client else: api_key = api_key or os.environ.get('GROQ_API_KEY') if not api_key: raise UserError( 'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`' 'to use the Groq provider.' ) elif http_client is not None: self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='groq') self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` ### AzureProvider Bases: `Provider[AsyncOpenAI]` Provider for Azure OpenAI API. See for more information. Source code in `pydantic_ai_slim/pydantic_ai/providers/azure.py` ```python class AzureProvider(Provider[AsyncOpenAI]): """Provider for Azure OpenAI API. See for more information. 
""" @property def name(self) -> str: return 'azure' @property def base_url(self) -> str: assert self._base_url is not None return self._base_url @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: model_name = model_name.lower() prefix_to_profile = { 'llama': meta_model_profile, 'meta-': meta_model_profile, 'deepseek': deepseek_model_profile, 'mistralai-': mistral_model_profile, 'mistral': mistral_model_profile, 'cohere-': cohere_model_profile, 'grok': grok_model_profile, } for prefix, profile_func in prefix_to_profile.items(): if model_name.startswith(prefix): if prefix.endswith('-'): model_name = model_name[len(prefix) :] profile = profile_func(model_name) # As AzureProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer, # we need to maintain that behavior unless json_schema_transformer is set explicitly return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) # OpenAI models are unprefixed return openai_model_profile(model_name) @overload def __init__(self, *, openai_client: AsyncAzureOpenAI) -> None: ... @overload def __init__( self, *, azure_endpoint: str | None = None, api_version: str | None = None, api_key: str | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: ... def __init__( self, *, azure_endpoint: str | None = None, api_version: str | None = None, api_key: str | None = None, openai_client: AsyncAzureOpenAI | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Azure provider. Args: azure_endpoint: The Azure endpoint to use for authentication, if not provided, the `AZURE_OPENAI_ENDPOINT` environment variable will be used if available. api_version: The API version to use for authentication, if not provided, the `OPENAI_API_VERSION` environment variable will be used if available. api_key: The API key to use for authentication, if not provided, the `AZURE_OPENAI_API_KEY` environment variable will be used if available. openai_client: An existing [`AsyncAzureOpenAI`](https://github.com/openai/openai-python#microsoft-azure-openai) client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. 
""" if openai_client is not None: assert azure_endpoint is None, 'Cannot provide both `openai_client` and `azure_endpoint`' assert http_client is None, 'Cannot provide both `openai_client` and `http_client`' assert api_key is None, 'Cannot provide both `openai_client` and `api_key`' self._base_url = str(openai_client.base_url) self._client = openai_client else: azure_endpoint = azure_endpoint or os.getenv('AZURE_OPENAI_ENDPOINT') if not azure_endpoint: raise UserError( 'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable' ) if not api_key and 'AZURE_OPENAI_API_KEY' not in os.environ: # pragma: no cover raise UserError( 'Must provide one of the `api_key` argument or the `AZURE_OPENAI_API_KEY` environment variable' ) if not api_version and 'OPENAI_API_VERSION' not in os.environ: # pragma: no cover raise UserError( 'Must provide one of the `api_version` argument or the `OPENAI_API_VERSION` environment variable' ) http_client = http_client or cached_async_http_client(provider='azure') self._client = AsyncAzureOpenAI( azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version, http_client=http_client, ) self._base_url = str(self._client.base_url) ``` #### __init__ ```python __init__(*, openai_client: AsyncAzureOpenAI) -> None ``` ```python __init__( *, azure_endpoint: str | None = None, api_version: str | None = None, api_key: str | None = None, http_client: AsyncClient | None = None ) -> None ``` ```python __init__( *, azure_endpoint: str | None = None, api_version: str | None = None, api_key: str | None = None, openai_client: AsyncAzureOpenAI | None = None, http_client: AsyncClient | None = None ) -> None ``` Create a new Azure provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `azure_endpoint` | `str | None` | The Azure endpoint to use for authentication, if not provided, the AZURE_OPENAI_ENDPOINT environment variable will be used if available. | `None` | | `api_version` | `str | None` | The API version to use for authentication, if not provided, the OPENAI_API_VERSION environment variable will be used if available. | `None` | | `api_key` | `str | None` | The API key to use for authentication, if not provided, the AZURE_OPENAI_API_KEY environment variable will be used if available. | `None` | | `openai_client` | `AsyncAzureOpenAI | None` | An existing AsyncAzureOpenAI client to use. If provided, base_url, api_key, and http_client must be None. | `None` | | `http_client` | `AsyncClient | None` | An existing httpx.AsyncClient to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/azure.py` ```python def __init__( self, *, azure_endpoint: str | None = None, api_version: str | None = None, api_key: str | None = None, openai_client: AsyncAzureOpenAI | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: """Create a new Azure provider. Args: azure_endpoint: The Azure endpoint to use for authentication, if not provided, the `AZURE_OPENAI_ENDPOINT` environment variable will be used if available. api_version: The API version to use for authentication, if not provided, the `OPENAI_API_VERSION` environment variable will be used if available. api_key: The API key to use for authentication, if not provided, the `AZURE_OPENAI_API_KEY` environment variable will be used if available. openai_client: An existing [`AsyncAzureOpenAI`](https://github.com/openai/openai-python#microsoft-azure-openai) client to use. 
If provided, `base_url`, `api_key`, and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ if openai_client is not None: assert azure_endpoint is None, 'Cannot provide both `openai_client` and `azure_endpoint`' assert http_client is None, 'Cannot provide both `openai_client` and `http_client`' assert api_key is None, 'Cannot provide both `openai_client` and `api_key`' self._base_url = str(openai_client.base_url) self._client = openai_client else: azure_endpoint = azure_endpoint or os.getenv('AZURE_OPENAI_ENDPOINT') if not azure_endpoint: raise UserError( 'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable' ) if not api_key and 'AZURE_OPENAI_API_KEY' not in os.environ: # pragma: no cover raise UserError( 'Must provide one of the `api_key` argument or the `AZURE_OPENAI_API_KEY` environment variable' ) if not api_version and 'OPENAI_API_VERSION' not in os.environ: # pragma: no cover raise UserError( 'Must provide one of the `api_version` argument or the `OPENAI_API_VERSION` environment variable' ) http_client = http_client or cached_async_http_client(provider='azure') self._client = AsyncAzureOpenAI( azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version, http_client=http_client, ) self._base_url = str(self._client.base_url) ``` ### CohereProvider Bases: `Provider[AsyncClientV2]` Provider for Cohere API. Source code in `pydantic_ai_slim/pydantic_ai/providers/cohere.py` ```python class CohereProvider(Provider[AsyncClientV2]): """Provider for Cohere API.""" @property def name(self) -> str: return 'cohere' @property def base_url(self) -> str: client_wrapper = self.client._client_wrapper # type: ignore return str(client_wrapper.get_base_url()) @property def client(self) -> AsyncClientV2: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return cohere_model_profile(model_name) def __init__( self, *, api_key: str | None = None, cohere_client: AsyncClientV2 | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new Cohere provider. Args: api_key: The API key to use for authentication, if not provided, the `CO_API_KEY` environment variable will be used if available. cohere_client: An existing [AsyncClientV2](https://github.com/cohere-ai/cohere-python) client to use. If provided, `api_key` and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ if cohere_client is not None: assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`' assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`' self._client = cohere_client else: api_key = api_key or os.environ.get('CO_API_KEY') if not api_key: raise UserError( 'Set the `CO_API_KEY` environment variable or pass it via `CohereProvider(api_key=...)`' 'to use the Cohere provider.' ) base_url = os.environ.get('CO_BASE_URL') if http_client is not None: self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url) else: http_client = cached_async_http_client(provider='cohere') self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url) ``` #### __init__ ```python __init__( *, api_key: str | None = None, cohere_client: AsyncClientV2 | None = None, http_client: AsyncClient | None = None ) -> None ``` Create a new Cohere provider. 
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `api_key` | `str | None` | The API key to use for authentication, if not provided, the CO_API_KEY environment variable will be used if available. | `None` | | `cohere_client` | `AsyncClientV2 | None` | An existing AsyncClientV2 client to use. If provided, api_key and http_client must be None. | `None` | | `http_client` | `AsyncClient | None` | An existing httpx.AsyncClient to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/cohere.py` ```python def __init__( self, *, api_key: str | None = None, cohere_client: AsyncClientV2 | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new Cohere provider. Args: api_key: The API key to use for authentication, if not provided, the `CO_API_KEY` environment variable will be used if available. cohere_client: An existing [AsyncClientV2](https://github.com/cohere-ai/cohere-python) client to use. If provided, `api_key` and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ if cohere_client is not None: assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`' assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`' self._client = cohere_client else: api_key = api_key or os.environ.get('CO_API_KEY') if not api_key: raise UserError( 'Set the `CO_API_KEY` environment variable or pass it via `CohereProvider(api_key=...)`' 'to use the Cohere provider.' ) base_url = os.environ.get('CO_BASE_URL') if http_client is not None: self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url) else: http_client = cached_async_http_client(provider='cohere') self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url) ``` Bases: `Provider[Mistral]` Provider for Mistral API. Source code in `pydantic_ai_slim/pydantic_ai/providers/mistral.py` ```python class MistralProvider(Provider[Mistral]): """Provider for Mistral API.""" @property def name(self) -> str: return 'mistral' @property def base_url(self) -> str: return self.client.sdk_configuration.get_server_details()[0] @property def client(self) -> Mistral: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: return mistral_model_profile(model_name) @overload def __init__(self, *, mistral_client: Mistral | None = None) -> None: ... @overload def __init__(self, *, api_key: str | None = None, http_client: AsyncHTTPClient | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, mistral_client: Mistral | None = None, base_url: str | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new Mistral provider. Args: api_key: The API key to use for authentication, if not provided, the `MISTRAL_API_KEY` environment variable will be used if available. mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`. base_url: The base url for the Mistral requests. http_client: An existing async client to use for making HTTP requests. 
""" if mistral_client is not None: assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`' assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`' assert base_url is None, 'Cannot provide both `mistral_client` and `base_url`' self._client = mistral_client else: api_key = api_key or os.environ.get('MISTRAL_API_KEY') if not api_key: raise UserError( 'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`' 'to use the Mistral provider.' ) elif http_client is not None: self._client = Mistral(api_key=api_key, async_client=http_client, server_url=base_url) else: http_client = cached_async_http_client(provider='mistral') self._client = Mistral(api_key=api_key, async_client=http_client, server_url=base_url) ``` ### __init__ ```python __init__(*, mistral_client: Mistral | None = None) -> None ``` ```python __init__( *, api_key: str | None = None, http_client: AsyncClient | None = None ) -> None ``` ```python __init__( *, api_key: str | None = None, mistral_client: Mistral | None = None, base_url: str | None = None, http_client: AsyncClient | None = None ) -> None ``` Create a new Mistral provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `api_key` | `str | None` | The API key to use for authentication, if not provided, the MISTRAL_API_KEY environment variable will be used if available. | `None` | | `mistral_client` | `Mistral | None` | An existing Mistral client to use, if provided, api_key and http_client must be None. | `None` | | `base_url` | `str | None` | The base url for the Mistral requests. | `None` | | `http_client` | `AsyncClient | None` | An existing async client to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/mistral.py` ```python def __init__( self, *, api_key: str | None = None, mistral_client: Mistral | None = None, base_url: str | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new Mistral provider. Args: api_key: The API key to use for authentication, if not provided, the `MISTRAL_API_KEY` environment variable will be used if available. mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`. base_url: The base url for the Mistral requests. http_client: An existing async client to use for making HTTP requests. """ if mistral_client is not None: assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`' assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`' assert base_url is None, 'Cannot provide both `mistral_client` and `base_url`' self._client = mistral_client else: api_key = api_key or os.environ.get('MISTRAL_API_KEY') if not api_key: raise UserError( 'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`' 'to use the Mistral provider.' ) elif http_client is not None: self._client = Mistral(api_key=api_key, async_client=http_client, server_url=base_url) else: http_client = cached_async_http_client(provider='mistral') self._client = Mistral(api_key=api_key, async_client=http_client, server_url=base_url) ``` Bases: `Provider[AsyncOpenAI]` Provider for Fireworks AI API. 
Source code in `pydantic_ai_slim/pydantic_ai/providers/fireworks.py` ```python class FireworksProvider(Provider[AsyncOpenAI]): """Provider for Fireworks AI API.""" @property def name(self) -> str: return 'fireworks' @property def base_url(self) -> str: return 'https://api.fireworks.ai/inference/v1' @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: prefix_to_profile = { 'llama': meta_model_profile, 'qwen': qwen_model_profile, 'deepseek': deepseek_model_profile, 'mistral': mistral_model_profile, 'gemma': google_model_profile, } prefix = 'accounts/fireworks/models/' profile = None if model_name.startswith(prefix): model_name = model_name[len(prefix) :] for provider, profile_func in prefix_to_profile.items(): if model_name.startswith(provider): profile = profile_func(model_name) break # As the Fireworks API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer, # unless json_schema_transformer is set explicitly return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) @overload def __init__(self) -> None: ... @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: api_key = api_key or os.getenv('FIREWORKS_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `FIREWORKS_API_KEY` environment variable or pass it via `FireworksProvider(api_key=...)`' 'to use the Fireworks AI provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='fireworks') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` Bases: `Provider[AsyncOpenAI]` Provider for Grok API. Source code in `pydantic_ai_slim/pydantic_ai/providers/grok.py` ```python class GrokProvider(Provider[AsyncOpenAI]): """Provider for Grok API.""" @property def name(self) -> str: return 'grok' @property def base_url(self) -> str: return 'https://api.x.ai/v1' @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: profile = grok_model_profile(model_name) # As the Grok API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer, # unless json_schema_transformer is set explicitly. # Also, Grok does not support strict tool definitions: https://github.com/pydantic/pydantic-ai/issues/1846 return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, openai_supports_strict_tool_definition=False ).update(profile) @overload def __init__(self) -> None: ... @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... 
def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: api_key = api_key or os.getenv('GROK_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `GROK_API_KEY` environment variable or pass it via `GrokProvider(api_key=...)`' 'to use the Grok provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='grok') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` Bases: `Provider[AsyncOpenAI]` Provider for Together AI API. Source code in `pydantic_ai_slim/pydantic_ai/providers/together.py` ```python class TogetherProvider(Provider[AsyncOpenAI]): """Provider for Together AI API.""" @property def name(self) -> str: return 'together' @property def base_url(self) -> str: return 'https://api.together.xyz/v1' @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: provider_to_profile = { 'deepseek-ai': deepseek_model_profile, 'google': google_model_profile, 'qwen': qwen_model_profile, 'meta-llama': meta_model_profile, 'mistralai': mistral_model_profile, } profile = None model_name = model_name.lower() provider, model_name = model_name.split('/', 1) if provider in provider_to_profile: profile = provider_to_profile[provider](model_name) # As the Together API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer, # unless json_schema_transformer is set explicitly return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) @overload def __init__(self) -> None: ... @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: api_key = api_key or os.getenv('TOGETHER_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `TOGETHER_API_KEY` environment variable or pass it via `TogetherProvider(api_key=...)`' 'to use the Together AI provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='together') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` Bases: `Provider[AsyncOpenAI]` Provider for Heroku API. Source code in `pydantic_ai_slim/pydantic_ai/providers/heroku.py` ```python class HerokuProvider(Provider[AsyncOpenAI]): """Provider for Heroku API.""" @property def name(self) -> str: return 'heroku' @property def base_url(self) -> str: return str(self.client.base_url) @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: # As the Heroku API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer. return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer) @overload def __init__(self) -> None: ... 
@overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, base_url: str | None = None, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: if openai_client is not None: assert http_client is None, 'Cannot provide both `openai_client` and `http_client`' assert api_key is None, 'Cannot provide both `openai_client` and `api_key`' self._client = openai_client else: api_key = api_key or os.environ.get('HEROKU_INFERENCE_KEY') if not api_key: raise UserError( 'Set the `HEROKU_INFERENCE_KEY` environment variable or pass it via `HerokuProvider(api_key=...)`' 'to use the Heroku provider.' ) base_url = base_url or os.environ.get('HEROKU_INFERENCE_URL', 'https://us.inference.heroku.com') base_url = base_url.rstrip('/') + '/v1' if http_client is not None: self._client = AsyncOpenAI(api_key=api_key, http_client=http_client, base_url=base_url) else: http_client = cached_async_http_client(provider='heroku') self._client = AsyncOpenAI(api_key=api_key, http_client=http_client, base_url=base_url) ``` Bases: `Provider[AsyncOpenAI]` Provider for GitHub Models API. GitHub Models provides access to various AI models through an OpenAI-compatible API. See for more information. Source code in `pydantic_ai_slim/pydantic_ai/providers/github.py` ```python class GitHubProvider(Provider[AsyncOpenAI]): """Provider for GitHub Models API. GitHub Models provides access to various AI models through an OpenAI-compatible API. See for more information. """ @property def name(self) -> str: return 'github' @property def base_url(self) -> str: return 'https://models.github.ai/inference' @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: provider_to_profile = { 'xai': grok_model_profile, 'meta': meta_model_profile, 'microsoft': openai_model_profile, 'mistral-ai': mistral_model_profile, 'cohere': cohere_model_profile, 'deepseek': deepseek_model_profile, } profile = None # If the model name does not contain a provider prefix, we assume it's an OpenAI model if '/' not in model_name: return openai_model_profile(model_name) provider, model_name = model_name.lower().split('/', 1) if provider in provider_to_profile: model_name, *_ = model_name.split(':', 1) # drop tags profile = provider_to_profile[provider](model_name) # As GitHubProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer, # we need to maintain that behavior unless json_schema_transformer is set explicitly return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) @overload def __init__(self) -> None: ... @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new GitHub Models provider. Args: api_key: The GitHub token to use for authentication. If not provided, the `GITHUB_API_KEY` environment variable will be used if available. openai_client: An existing `AsyncOpenAI` client to use. 
If provided, `api_key` and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ api_key = api_key or os.getenv('GITHUB_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`' ' to use the GitHub Models provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='github') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` ### __init__ ```python __init__() -> None ``` ```python __init__(*, api_key: str) -> None ``` ```python __init__(*, api_key: str, http_client: AsyncClient) -> None ``` ```python __init__( *, openai_client: AsyncOpenAI | None = None ) -> None ``` ```python __init__( *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncClient | None = None ) -> None ``` Create a new GitHub Models provider. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `api_key` | `str | None` | The GitHub token to use for authentication. If not provided, the GITHUB_API_KEY environment variable will be used if available. | `None` | | `openai_client` | `AsyncOpenAI | None` | An existing AsyncOpenAI client to use. If provided, api_key and http_client must be None. | `None` | | `http_client` | `AsyncClient | None` | An existing httpx.AsyncClient to use for making HTTP requests. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/providers/github.py` ```python def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: """Create a new GitHub Models provider. Args: api_key: The GitHub token to use for authentication. If not provided, the `GITHUB_API_KEY` environment variable will be used if available. openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. """ api_key = api_key or os.getenv('GITHUB_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`' ' to use the GitHub Models provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='github') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` Bases: `Provider[AsyncOpenAI]` Provider for OpenRouter API. 
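As with the providers above, here's a short, hedged sketch of using `OpenRouterProvider` before its source listing, this time passing an explicit `httpx.AsyncClient` as allowed by the `api_key`/`http_client` overload. The `OpenAIModel(..., provider=...)` wiring is assumed rather than shown in this section, and the model name is illustrative; OpenRouter model names use the `<provider>/<model>` form handled by `model_profile()` below.

```python
import httpx

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openrouter import OpenRouterProvider

# An explicit httpx client, e.g. to set a custom timeout; if omitted, a cached
# async client is created for the 'openrouter' provider (see __init__ below).
custom_http_client = httpx.AsyncClient(timeout=30)

provider = OpenRouterProvider(api_key='your-openrouter-api-key', http_client=custom_http_client)

# Illustrative model name; the 'anthropic/' prefix is one of those mapped to a
# model profile in model_profile() below.
model = OpenAIModel('anthropic/claude-3.5-sonnet', provider=provider)

agent = Agent(model)
result = agent.run_sync('Say hello in French.')
print(result.output)
```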
Source code in `pydantic_ai_slim/pydantic_ai/providers/openrouter.py` ```python class OpenRouterProvider(Provider[AsyncOpenAI]): """Provider for OpenRouter API.""" @property def name(self) -> str: return 'openrouter' @property def base_url(self) -> str: return 'https://openrouter.ai/api/v1' @property def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: provider_to_profile = { 'google': google_model_profile, 'openai': openai_model_profile, 'anthropic': anthropic_model_profile, 'mistralai': mistral_model_profile, 'qwen': qwen_model_profile, 'x-ai': grok_model_profile, 'cohere': cohere_model_profile, 'amazon': amazon_model_profile, 'deepseek': deepseek_model_profile, 'meta-llama': meta_model_profile, } profile = None provider, model_name = model_name.split('/', 1) if provider in provider_to_profile: model_name, *_ = model_name.split(':', 1) # drop tags profile = provider_to_profile[provider](model_name) # As OpenRouterProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer, # we need to maintain that behavior unless json_schema_transformer is set explicitly return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) @overload def __init__(self) -> None: ... @overload def __init__(self, *, api_key: str) -> None: ... @overload def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... @overload def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... def __init__( self, *, api_key: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: AsyncHTTPClient | None = None, ) -> None: api_key = api_key or os.getenv('OPENROUTER_API_KEY') if not api_key and openai_client is None: raise UserError( 'Set the `OPENROUTER_API_KEY` environment variable or pass it via `OpenRouterProvider(api_key=...)`' 'to use the OpenRouter provider.' ) if openai_client is not None: self._client = openai_client elif http_client is not None: self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) else: http_client = cached_async_http_client(provider='openrouter') self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) ``` # `pydantic_ai.result` ### StreamedRunResult Bases: `Generic[AgentDepsT, OutputDataT]` Result of a streamed run that returns structured data via a tool call. Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python @dataclass class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): """Result of a streamed run that returns structured data via a tool call.""" _all_messages: list[_messages.ModelMessage] _new_message_index: int _usage_limits: UsageLimits | None _stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] _initial_run_ctx_usage: Usage = field(init=False) is_complete: bool = field(default=False, init=False) """Whether the stream has all been received. This is set to `True` when one of [`stream`][pydantic_ai.result.StreamedRunResult.stream], [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text], [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes. 
""" def __post_init__(self): self._initial_run_ctx_usage = copy(self._run_ctx.usage) @overload def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... def all_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: """Return the history of _messages. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: List of messages. """ # this is a method to be consistent with the other methods content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) if content is not None: raise NotImplementedError('Setting output tool return content is not supported for this result type.') return self._all_messages @overload def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ... def all_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: # pragma: no cover """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResult.all_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the messages. """ content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content)) @overload def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ... def new_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: # pragma: no cover """Return new messages associated with this run. Messages from older runs are excluded. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. 
result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: List of new messages. """ content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return self.all_messages(output_tool_return_content=content)[self._new_message_index :] @overload def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ... @overload @deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.') def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ... def new_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: # pragma: no cover """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResult.new_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the new messages. """ content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content)) async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]: """Stream the response as an async iterable. The pydantic validator for structured data will be called in [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation) on each iteration. Args: debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. Returns: An async iterable of the response data. """ async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): try: yield await self.validate_structured_output(structured_message, allow_partial=not is_last) except ValidationError: if is_last: raise # pragma: lax no cover async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. !!! note Result validators will NOT be called on the text result if `delta=True`. Args: delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text up to the current point. debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. 
""" if not isinstance(self._output_schema, PlainTextOutputSchema): raise exceptions.UserError('stream_text() can only be used with text responses') if delta: async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): yield text else: async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): combined_validated_text = await self._validate_text_output(text) yield combined_validated_text await self._marked_completed(self._stream_response.get()) async def stream_structured( self, *, debounce_by: float | None = 0.1 ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]: """Stream the response as an async iterable of Structured LLM Messages. Args: debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. Returns: An async iterable of the structured response message and whether that is the last message. """ # if the message currently has any parts with content, yield before streaming msg = self._stream_response.get() for part in msg.parts: if part.has_content(): yield msg, False break async for msg in self._stream_response_structured(debounce_by=debounce_by): yield msg, False msg = self._stream_response.get() yield msg, True await self._marked_completed(msg) async def get_output(self) -> OutputDataT: """Stream the whole response, validate and return it.""" usage_checking_stream = _get_usage_checking_stream_response( self._stream_response, self._usage_limits, self.usage ) async for _ in usage_checking_stream: pass message = self._stream_response.get() await self._marked_completed(message) return await self.validate_structured_output(message) @deprecated('`get_data` is deprecated, use `get_output` instead.') async def get_data(self) -> OutputDataT: return await self.get_output() def usage(self) -> Usage: """Return the usage of the whole run. !!! note This won't return the full usage until the stream is finished. 
""" return self._initial_run_ctx_usage + self._stream_response.usage() def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._stream_response.timestamp @deprecated('`validate_structured_result` is deprecated, use `validate_structured_output` instead.') async def validate_structured_result( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: return await self.validate_structured_output(message, allow_partial=allow_partial) async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" call = None if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' ) call, output_tool = match result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover return result_data async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: text = await validator.validate(text, None, self._run_ctx) # pragma: no cover return text async def _marked_completed(self, message: _messages.ModelResponse) -> None: self.is_complete = True self._all_messages.append(message) await self._on_complete() async def _stream_response_structured( self, *, debounce_by: float | None = 0.1 ) -> AsyncIterator[_messages.ModelResponse]: async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: async for _items in group_iter: yield self._stream_response.get() async def _stream_response_text( self, *, delta: bool = False, debounce_by: float | None = 0.1 ) -> AsyncIterator[str]: """Stream the response as an async iterable of text.""" # Define a "merged" version of the iterator that will yield items that have already been retrieved # and items that we receive while streaming. We define a dedicated async iterator for this so we can # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. 
async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: # yields tuples of (text_content, part_index) # we don't currently make use of the part_index, but in principle this may be useful # so we retain it here for now to make possible future refactors simpler msg = self._stream_response.get() for i, part in enumerate(msg.parts): if isinstance(part, _messages.TextPart) and part.content: yield part.content, i async for event in self._stream_response: if ( isinstance(event, _messages.PartStartEvent) and isinstance(event.part, _messages.TextPart) and event.part.content ): yield event.part.content, event.index # pragma: no cover elif ( # pragma: no branch isinstance(event, _messages.PartDeltaEvent) and isinstance(event.delta, _messages.TextPartDelta) and event.delta.content_delta ): yield event.delta.content_delta, event.index async def _stream_text_deltas() -> AsyncIterator[str]: async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: async for items in group_iter: # Note: we are currently just dropping the part index on the group here yield ''.join([content for content, _ in items]) if delta: async for text in _stream_text_deltas(): yield text else: # a quick benchmark shows it's faster to build up a string with concat when we're # yielding at each step deltas: list[str] = [] async for text in _stream_text_deltas(): deltas.append(text) yield ''.join(deltas) ``` #### is_complete ```python is_complete: bool = field(default=False, init=False) ``` Whether the stream has all been received. This is set to `True` when one of stream, stream_text, stream_structured or get_output completes. #### all_messages ```python all_messages( *, output_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python all_messages( *, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python all_messages( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` Return the history of \_messages. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `list[ModelMessage]` | List of messages. | Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python def all_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: """Return the history of _messages. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: List of messages. 
""" # this is a method to be consistent with the other methods content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) if content is not None: raise NotImplementedError('Setting output tool return content is not supported for this result type.') return self._all_messages ``` #### all_messages_json ```python all_messages_json( *, output_tool_return_content: str | None = None ) -> bytes ``` ```python all_messages_json( *, result_tool_return_content: str | None = None ) -> bytes ``` ```python all_messages_json( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes ``` Return all messages from all_messages as JSON bytes. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `bytes` | JSON bytes representing the messages. | Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python def all_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: # pragma: no cover """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResult.all_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the messages. """ content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content)) ``` #### new_messages ```python new_messages( *, output_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python new_messages( *, output_tool_return_content: str | None = None ) -> list[ModelMessage] ``` ```python new_messages( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[ModelMessage] ``` Return new messages associated with this run. Messages from older runs are excluded. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `list[ModelMessage]` | List of new messages. 
| Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python def new_messages( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> list[_messages.ModelMessage]: # pragma: no cover """Return new messages associated with this run. Messages from older runs are excluded. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: List of new messages. """ content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return self.all_messages(output_tool_return_content=content)[self._new_message_index :] ``` #### new_messages_json ```python new_messages_json( *, output_tool_return_content: str | None = None ) -> bytes ``` ```python new_messages_json( *, result_tool_return_content: str | None = None ) -> bytes ``` ```python new_messages_json( *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes ``` Return new messages from new_messages as JSON bytes. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `output_tool_return_content` | `str | None` | The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If None, the last message will not be modified. | `None` | | `result_tool_return_content` | `str | None` | deprecated, use output_tool_return_content instead. | `None` | Returns: | Type | Description | | --- | --- | | `bytes` | JSON bytes representing the new messages. | Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python def new_messages_json( self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None ) -> bytes: # pragma: no cover """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResult.new_messages] as JSON bytes. Args: output_tool_return_content: The return content of the tool call to set in the last message. This provides a convenient way to modify the content of the output tool call if you want to continue the conversation and want to set the response to the output tool call. If `None`, the last message will not be modified. result_tool_return_content: deprecated, use `output_tool_return_content` instead. Returns: JSON bytes representing the new messages. """ content = coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content) return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content)) ``` #### stream ```python stream( *, debounce_by: float | None = 0.1 ) -> AsyncIterator[OutputDataT] ``` Stream the response as an async iterable. The pydantic validator for structured data will be called in [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation) on each iteration. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `debounce_by` | `float | None` | by how much (if at all) to debounce/group the response chunks by. None means no debouncing. 
Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. | `0.1` | Returns: | Type | Description | | --- | --- | | `AsyncIterator[OutputDataT]` | An async iterable of the response data. | Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]: """Stream the response as an async iterable. The pydantic validator for structured data will be called in [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation) on each iteration. Args: debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. Returns: An async iterable of the response data. """ async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): try: yield await self.validate_structured_output(structured_message, allow_partial=not is_last) except ValidationError: if is_last: raise # pragma: lax no cover ``` #### stream_text ```python stream_text( *, delta: bool = False, debounce_by: float | None = 0.1 ) -> AsyncIterator[str] ``` Stream the text result as an async iterable. Note Result validators will NOT be called on the text result if `delta=True`. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `delta` | `bool` | if True, yield each chunk of text as it is received, if False (default), yield the full text up to the current point. | `False` | | `debounce_by` | `float | None` | by how much (if at all) to debounce/group the response chunks by. None means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. | `0.1` | Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. !!! note Result validators will NOT be called on the text result if `delta=True`. Args: delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text up to the current point. debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ if not isinstance(self._output_schema, PlainTextOutputSchema): raise exceptions.UserError('stream_text() can only be used with text responses') if delta: async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): yield text else: async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): combined_validated_text = await self._validate_text_output(text) yield combined_validated_text await self._marked_completed(self._stream_response.get()) ``` #### stream_structured ```python stream_structured( *, debounce_by: float | None = 0.1 ) -> AsyncIterator[tuple[ModelResponse, bool]] ``` Stream the response as an async iterable of Structured LLM Messages. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `debounce_by` | `float | None` | by how much (if at all) to debounce/group the response chunks by. 
None means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. | `0.1` | Returns: | Type | Description | | --- | --- | | `AsyncIterator[tuple[ModelResponse, bool]]` | An async iterable of the structured response message and whether that is the last message. | Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python async def stream_structured( self, *, debounce_by: float | None = 0.1 ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]: """Stream the response as an async iterable of Structured LLM Messages. Args: debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. Returns: An async iterable of the structured response message and whether that is the last message. """ # if the message currently has any parts with content, yield before streaming msg = self._stream_response.get() for part in msg.parts: if part.has_content(): yield msg, False break async for msg in self._stream_response_structured(debounce_by=debounce_by): yield msg, False msg = self._stream_response.get() yield msg, True await self._marked_completed(msg) ``` #### get_output ```python get_output() -> OutputDataT ``` Stream the whole response, validate and return it. Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python async def get_output(self) -> OutputDataT: """Stream the whole response, validate and return it.""" usage_checking_stream = _get_usage_checking_stream_response( self._stream_response, self._usage_limits, self.usage ) async for _ in usage_checking_stream: pass message = self._stream_response.get() await self._marked_completed(message) return await self.validate_structured_output(message) ``` #### usage ```python usage() -> Usage ``` Return the usage of the whole run. Note This won't return the full usage until the stream is finished. Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python def usage(self) -> Usage: """Return the usage of the whole run. !!! note This won't return the full usage until the stream is finished. """ return self._initial_run_ctx_usage + self._stream_response.usage() ``` #### timestamp ```python timestamp() -> datetime ``` Get the timestamp of the response. Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._stream_response.timestamp ``` #### validate_structured_output ```python validate_structured_output( message: ModelResponse, *, allow_partial: bool = False ) -> OutputDataT ``` Validate a structured result message. 
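Putting the streaming methods above together — `stream_text()` while the response is arriving, then `usage()` once it has finished — here is a minimal sketch of consuming a `StreamedRunResult` obtained from `Agent.run_stream` (the model name and prompt are illustrative, and a configured OpenAI API key is assumed):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    async with agent.run_stream('Explain nucleus sampling in one paragraph.') as result:
        # stream_text() yields the accumulated text so far on each debounced chunk;
        # pass delta=True to receive only the newly arrived text each time instead.
        async for text in result.stream_text():
            print(text)
        # usage() only reflects the full run once the stream has finished
        print(result.usage())


asyncio.run(main())
```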
Source code in `pydantic_ai_slim/pydantic_ai/result.py` ```python async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" call = None if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' ) call, output_tool = match result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover return result_data ``` # `pydantic_ai.settings` ### ModelSettings Bases: `TypedDict` Settings to configure an LLM. Here we include only settings which apply to multiple models / model providers, though not all of these settings are supported by all models. Source code in `pydantic_ai_slim/pydantic_ai/settings.py` ```python class ModelSettings(TypedDict, total=False): """Settings to configure an LLM. Here we include only settings which apply to multiple models / model providers, though not all of these settings are supported by all models. """ max_tokens: int """The maximum number of tokens to generate before stopping. Supported by: * Gemini * Anthropic * OpenAI * Groq * Cohere * Mistral * Bedrock * MCP Sampling """ temperature: float """Amount of randomness injected into the response. Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's maximum `temperature` for creative and generative tasks. Note that even with `temperature` of `0.0`, the results will not be fully deterministic. Supported by: * Gemini * Anthropic * OpenAI * Groq * Cohere * Mistral * Bedrock """ top_p: float """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. You should either alter `temperature` or `top_p`, but not both. Supported by: * Gemini * Anthropic * OpenAI * Groq * Cohere * Mistral * Bedrock """ timeout: float | Timeout """Override the client-level default timeout for a request, in seconds. Supported by: * Gemini * Anthropic * OpenAI * Groq * Mistral """ parallel_tool_calls: bool """Whether to allow parallel tool calls. Supported by: * OpenAI (some models, not o1) * Groq * Anthropic """ seed: int """The random seed to use for the model, theoretically allowing for deterministic results. Supported by: * OpenAI * Groq * Cohere * Mistral """ presence_penalty: float """Penalize new tokens based on whether they have appeared in the text so far. Supported by: * OpenAI * Groq * Cohere * Gemini * Mistral """ frequency_penalty: float """Penalize new tokens based on their existing frequency in the text so far. 
Supported by: * OpenAI * Groq * Cohere * Gemini * Mistral """ logit_bias: dict[str, int] """Modify the likelihood of specified tokens appearing in the completion. Supported by: * OpenAI * Groq """ stop_sequences: list[str] """Sequences that will cause the model to stop generating. Supported by: * OpenAI * Anthropic * Bedrock * Mistral * Groq * Cohere * Google """ extra_headers: dict[str, str] """Extra headers to send to the model. Supported by: * OpenAI * Anthropic * Groq """ extra_body: object """Extra body to send to the model. Supported by: * OpenAI * Anthropic * Groq """ ``` #### max_tokens ```python max_tokens: int ``` The maximum number of tokens to generate before stopping. Supported by: - Gemini - Anthropic - OpenAI - Groq - Cohere - Mistral - Bedrock - MCP Sampling #### temperature ```python temperature: float ``` Amount of randomness injected into the response. Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's maximum `temperature` for creative and generative tasks. Note that even with `temperature` of `0.0`, the results will not be fully deterministic. Supported by: - Gemini - Anthropic - OpenAI - Groq - Cohere - Mistral - Bedrock #### top_p ```python top_p: float ``` An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. You should either alter `temperature` or `top_p`, but not both. Supported by: - Gemini - Anthropic - OpenAI - Groq - Cohere - Mistral - Bedrock #### timeout ```python timeout: float | Timeout ``` Override the client-level default timeout for a request, in seconds. Supported by: - Gemini - Anthropic - OpenAI - Groq - Mistral #### parallel_tool_calls ```python parallel_tool_calls: bool ``` Whether to allow parallel tool calls. Supported by: - OpenAI (some models, not o1) - Groq - Anthropic #### seed ```python seed: int ``` The random seed to use for the model, theoretically allowing for deterministic results. Supported by: - OpenAI - Groq - Cohere - Mistral #### presence_penalty ```python presence_penalty: float ``` Penalize new tokens based on whether they have appeared in the text so far. Supported by: - OpenAI - Groq - Cohere - Gemini - Mistral #### frequency_penalty ```python frequency_penalty: float ``` Penalize new tokens based on their existing frequency in the text so far. Supported by: - OpenAI - Groq - Cohere - Gemini - Mistral #### logit_bias ```python logit_bias: dict[str, int] ``` Modify the likelihood of specified tokens appearing in the completion. Supported by: - OpenAI - Groq #### stop_sequences ```python stop_sequences: list[str] ``` Sequences that will cause the model to stop generating. Supported by: - OpenAI - Anthropic - Bedrock - Mistral - Groq - Cohere - Google #### extra_headers ```python extra_headers: dict[str, str] ``` Extra headers to send to the model. Supported by: - OpenAI - Anthropic - Groq #### extra_body ```python extra_body: object ``` Extra body to send to the model. Supported by: - OpenAI - Anthropic - Groq # `pydantic_ai.tools` ### AgentDepsT ```python AgentDepsT = TypeVar( "AgentDepsT", default=None, contravariant=True ) ``` Type variable for agent dependencies. ### RunContext Bases: `Generic[AgentDepsT]` Information about the current call. 
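A dynamic system prompt function is one common consumer of `RunContext`: it can read `ctx.deps` (tools can additionally inspect fields such as `retry` and `tool_name`). A minimal sketch — the model name and dependency type are illustrative:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('openai:gpt-4o', deps_type=str)


@agent.system_prompt
def add_user_name(ctx: RunContext[str]) -> str:
    # ctx.deps holds whatever was passed as `deps=` to `run`/`run_sync`
    return f"The user's name is {ctx.deps!r}."


result = agent.run_sync('Say hello to me by name.', deps='Frank')
print(result.output)
```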
Source code in `pydantic_ai_slim/pydantic_ai/_run_context.py` ```python @dataclasses.dataclass(repr=False) class RunContext(Generic[AgentDepsT]): """Information about the current call.""" deps: AgentDepsT """Dependencies for the agent.""" model: Model """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" prompt: str | Sequence[_messages.UserContent] | None """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" tool_call_id: str | None = None """The ID of the tool call.""" tool_name: str | None = None """Name of the tool being called.""" retry: int = 0 """Number of retries so far.""" run_step: int = 0 """The current step in the run.""" def replace_with( self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET, ) -> RunContext[AgentDepsT]: # Create a new `RunContext` a new `retry` value and `tool_name`. kwargs = {} if retry is not None: kwargs['retry'] = retry if tool_name is not _utils.UNSET: # pragma: no branch kwargs['tool_name'] = tool_name return dataclasses.replace(self, **kwargs) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### deps ```python deps: AgentDepsT ``` Dependencies for the agent. #### model ```python model: Model ``` The model used in this run. #### usage ```python usage: Usage ``` LLM usage associated with the run. #### prompt ```python prompt: str | Sequence[UserContent] | None ``` The original user prompt passed to the run. #### messages ```python messages: list[ModelMessage] = field(default_factory=list) ``` Messages exchanged in the conversation so far. #### tool_call_id ```python tool_call_id: str | None = None ``` The ID of the tool call. #### tool_name ```python tool_name: str | None = None ``` Name of the tool being called. #### retry ```python retry: int = 0 ``` Number of retries so far. #### run_step ```python run_step: int = 0 ``` The current step in the run. ### ToolParams ```python ToolParams = ParamSpec('ToolParams', default=...) ``` Retrieval function param spec. ### SystemPromptFunc ```python SystemPromptFunc = Union[ Callable[[RunContext[AgentDepsT]], str], Callable[[RunContext[AgentDepsT]], Awaitable[str]], Callable[[], str], Callable[[], Awaitable[str]], ] ``` A function that may or may not take `RunContext` as an argument, and may or may not be async. Usage `SystemPromptFunc[AgentDepsT]`. ### ToolFuncContext ```python ToolFuncContext = Callable[ Concatenate[RunContext[AgentDepsT], ToolParams], Any ] ``` A tool function that takes `RunContext` as the first argument. Usage `ToolFuncContext[AgentDepsT, ToolParams]`. ### ToolFuncPlain ```python ToolFuncPlain = Callable[ToolParams, Any] ``` A tool function that does not take `RunContext` as the first argument. Usage `ToolFuncPlain[ToolParams]`. ### ToolFuncEither ```python ToolFuncEither = Union[ ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams], ] ``` Either kind of tool function. This is just a union of ToolFuncContext and ToolFuncPlain. Usage `ToolFuncEither[AgentDepsT, ToolParams]`. ### ToolPrepareFunc ```python ToolPrepareFunc: TypeAlias = ( "Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]" ) ``` Definition of a function that can prepare a tool definition at call time. See [tool docs](../../tools/#tool-prepare) for more information.
Example — here `only_if_42` is valid as a `ToolPrepareFunc`: ```python from typing import Union from pydantic_ai import RunContext, Tool from pydantic_ai.tools import ToolDefinition async def only_if_42( ctx: RunContext[int], tool_def: ToolDefinition ) -> Union[ToolDefinition, None]: if ctx.deps == 42: return tool_def def hitchhiker(ctx: RunContext[int], answer: str) -> str: return f'{ctx.deps} {answer}' hitchhiker = Tool(hitchhiker, prepare=only_if_42) ``` Usage `ToolPrepareFunc[AgentDepsT]`. ### ToolsPrepareFunc ```python ToolsPrepareFunc: TypeAlias = ( "Callable[[RunContext[AgentDepsT], list[ToolDefinition]], Awaitable[list[ToolDefinition] | None]]" ) ``` Definition of a function that can prepare the tool definition of all tools for each step. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. Example — here `turn_on_strict_if_openai` is valid as a `ToolsPrepareFunc`: ```python from dataclasses import replace from typing import Union from pydantic_ai import Agent, RunContext from pydantic_ai.tools import ToolDefinition async def turn_on_strict_if_openai( ctx: RunContext[None], tool_defs: list[ToolDefinition] ) -> Union[list[ToolDefinition], None]: if ctx.model.system == 'openai': return [replace(tool_def, strict=True) for tool_def in tool_defs] return tool_defs agent = Agent('openai:gpt-4o', prepare_tools=turn_on_strict_if_openai) ``` Usage `ToolsPrepareFunc[AgentDepsT]`. ### DocstringFormat ```python DocstringFormat = Literal[ "google", "numpy", "sphinx", "auto" ] ``` Supported docstring formats. - `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings. - `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings. - `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings. - `'auto'` — Automatically infer the format based on the structure of the docstring. ### Tool Bases: `Generic[AgentDepsT]` A tool function for an agent. Source code in `pydantic_ai_slim/pydantic_ai/tools.py` ````python @dataclass(init=False) class Tool(Generic[AgentDepsT]): """A tool function for an agent.""" function: ToolFuncEither[AgentDepsT] takes_ctx: bool max_retries: int | None name: str description: str prepare: ToolPrepareFunc[AgentDepsT] | None docstring_format: DocstringFormat require_parameter_descriptions: bool strict: bool | None function_schema: _function_schema.FunctionSchema """ The base JSON schema for the tool's parameters. This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request. """ # TODO: Consider moving this current_retry state to live on something other than the tool. # We've worked around this for now by copying instances of the tool when creating new runs, # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things # up, though is also likely a larger effort to refactor. 
current_retry: int = field(default=0, init=False) def __init__( self, function: ToolFuncEither[AgentDepsT], *, takes_ctx: bool | None = None, max_retries: int | None = None, name: str | None = None, description: str | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, function_schema: _function_schema.FunctionSchema | None = None, ): """Create a new tool instance. Example usage: ```python {noqa="I001"} from pydantic_ai import Agent, RunContext, Tool async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: return f'{ctx.deps} {x} {y}' agent = Agent('test', tools=[Tool(my_tool)]) ``` or with a custom prepare method: ```python {noqa="I001"} from typing import Union from pydantic_ai import Agent, RunContext, Tool from pydantic_ai.tools import ToolDefinition async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: return f'{ctx.deps} {x} {y}' async def prep_my_tool( ctx: RunContext[int], tool_def: ToolDefinition ) -> Union[ToolDefinition, None]: # only register the tool if `deps == 42` if ctx.deps == 42: return tool_def agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)]) ``` Args: function: The Python function to call as the tool. takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument, this is inferred if unset. max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`. name: Name of the tool, inferred from the function if `None`. description: Description of the tool, inferred from the function if `None`. prepare: custom method to prepare the tool definition for each step, return `None` to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. function_schema: The function schema to use for the tool. If not provided, it will be generated. """ self.function = function self.function_schema = function_schema or _function_schema.function_schema( function, schema_generator, takes_ctx=takes_ctx, docstring_format=docstring_format, require_parameter_descriptions=require_parameter_descriptions, ) self.takes_ctx = self.function_schema.takes_ctx self.max_retries = max_retries self.name = name or function.__name__ self.description = description or self.function_schema.description self.prepare = prepare self.docstring_format = docstring_format self.require_parameter_descriptions = require_parameter_descriptions self.strict = strict @classmethod def from_schema( cls, function: Callable[..., Any], name: str, description: str, json_schema: JsonSchemaValue, ) -> Self: """Creates a Pydantic tool from a function and a JSON schema. Args: function: The function to call. 
This will be called with keywords only, and no validation of the arguments will be performed. name: The unique name of the tool that clearly communicates its purpose description: Used to tell the model how/when/why to use the tool. You can provide few-shot examples as a part of the description. json_schema: The schema for the function arguments Returns: A Pydantic tool that calls the function """ function_schema = _function_schema.FunctionSchema( function=function, description=description, validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, is_async=_utils.is_async_callable(function), ) return cls( function, takes_ctx=False, name=name, description=description, function_schema=function_schema, ) async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. By default, this method creates a tool definition, then either returns it, or calls `self.prepare` if it's set. Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ tool_def = ToolDefinition( name=self.name, description=self.description, parameters_json_schema=self.function_schema.json_schema, strict=self.strict, ) if self.prepare is not None: return await self.prepare(ctx, tool_def) else: return tool_def async def run( self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer, include_content: bool = False, ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: """Run the tool function asynchronously. This method wraps `_run` in an OpenTelemetry span. See . """ span_attributes = { 'gen_ai.tool.name': self.name, # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai 'gen_ai.tool.call.id': message.tool_call_id, **({'tool_arguments': message.args_as_json_str()} if include_content else {}), 'logfire.msg': f'running tool: {self.name}', # add the JSON schema so these attributes are formatted nicely in Logfire 'logfire.json_schema': json.dumps( { 'type': 'object', 'properties': { **( { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, } if include_content else {} ), 'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}, }, } ), } with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: response = await self._run(message, run_context) if include_content and span.is_recording(): span.set_attribute( 'tool_response', response.model_response_str() if isinstance(response, ToolReturnPart) else response.model_response(), ) return response async def _run( self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: try: validator = self.function_schema.validator if isinstance(message.args, str): args_dict = validator.validate_json(message.args or '{}') else: args_dict = validator.validate_python(message.args or {}) except ValidationError as e: return self._on_error(e, message) ctx = dataclasses.replace( run_context, retry=self.current_retry, tool_name=message.tool_name, tool_call_id=message.tool_call_id, ) try: response_content = await self.function_schema.call(args_dict, ctx) except ModelRetry as e: return self._on_error(e, message) self.current_retry = 0 return _messages.ToolReturnPart( tool_name=message.tool_name, content=response_content, tool_call_id=message.tool_call_id, ) def _on_error( self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart ) -> _messages.RetryPromptPart: 
self.current_retry += 1 if self.max_retries is None or self.current_retry > self.max_retries: raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc else: if isinstance(exc, ValidationError): content = exc.errors(include_url=False, include_context=False) else: content = exc.message return _messages.RetryPromptPart( tool_name=call_message.tool_name, content=content, tool_call_id=call_message.tool_call_id, ) ```` #### __init__ ```python __init__( function: ToolFuncEither[AgentDepsT], *, takes_ctx: bool | None = None, max_retries: int | None = None, name: str | None = None, description: str | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = "auto", require_parameter_descriptions: bool = False, schema_generator: type[ GenerateJsonSchema ] = GenerateToolJsonSchema, strict: bool | None = None, function_schema: FunctionSchema | None = None ) ``` Create a new tool instance. Example usage: ```python from pydantic_ai import Agent, RunContext, Tool async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: return f'{ctx.deps} {x} {y}' agent = Agent('test', tools=[Tool(my_tool)]) ``` or with a custom prepare method: ```python from typing import Union from pydantic_ai import Agent, RunContext, Tool from pydantic_ai.tools import ToolDefinition async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: return f'{ctx.deps} {x} {y}' async def prep_my_tool( ctx: RunContext[int], tool_def: ToolDefinition ) -> Union[ToolDefinition, None]: # only register the tool if `deps == 42` if ctx.deps == 42: return tool_def agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)]) ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `function` | `ToolFuncEither[AgentDepsT]` | The Python function to call as the tool. | *required* | | `takes_ctx` | `bool | None` | Whether the function takes a RunContext first argument, this is inferred if unset. | `None` | | `max_retries` | `int | None` | Maximum number of retries allowed for this tool, set to the agent default if None. | `None` | | `name` | `str | None` | Name of the tool, inferred from the function if None. | `None` | | `description` | `str | None` | Description of the tool, inferred from the function if None. | `None` | | `prepare` | `ToolPrepareFunc[AgentDepsT] | None` | custom method to prepare the tool definition for each step, return None to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See ToolPrepareFunc. | `None` | | `docstring_format` | `DocstringFormat` | The format of the docstring, see DocstringFormat. Defaults to 'auto', such that the format is inferred from the structure of the docstring. | `'auto'` | | `require_parameter_descriptions` | `bool` | If True, raise an error if a parameter description is missing. Defaults to False. | `False` | | `schema_generator` | `type[GenerateJsonSchema]` | The JSON schema generator class to use. Defaults to GenerateToolJsonSchema. | `GenerateToolJsonSchema` | | `strict` | `bool | None` | Whether to enforce JSON schema compliance (only affects OpenAI). See ToolDefinition for more info. | `None` | | `function_schema` | `FunctionSchema | None` | The function schema to use for the tool. If not provided, it will be generated. 
| `None` | Source code in `pydantic_ai_slim/pydantic_ai/tools.py` ````python def __init__( self, function: ToolFuncEither[AgentDepsT], *, takes_ctx: bool | None = None, max_retries: int | None = None, name: str | None = None, description: str | None = None, prepare: ToolPrepareFunc[AgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, function_schema: _function_schema.FunctionSchema | None = None, ): """Create a new tool instance. Example usage: ```python {noqa="I001"} from pydantic_ai import Agent, RunContext, Tool async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: return f'{ctx.deps} {x} {y}' agent = Agent('test', tools=[Tool(my_tool)]) ``` or with a custom prepare method: ```python {noqa="I001"} from typing import Union from pydantic_ai import Agent, RunContext, Tool from pydantic_ai.tools import ToolDefinition async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: return f'{ctx.deps} {x} {y}' async def prep_my_tool( ctx: RunContext[int], tool_def: ToolDefinition ) -> Union[ToolDefinition, None]: # only register the tool if `deps == 42` if ctx.deps == 42: return tool_def agent = Agent('test', tools=[Tool(my_tool, prepare=prep_my_tool)]) ``` Args: function: The Python function to call as the tool. takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument, this is inferred if unset. max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`. name: Name of the tool, inferred from the function if `None`. description: Description of the tool, inferred from the function if `None`. prepare: custom method to prepare the tool definition for each step, return `None` to omit this tool from a given step. This is useful if you want to customise a tool at call time, or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. function_schema: The function schema to use for the tool. If not provided, it will be generated. 
""" self.function = function self.function_schema = function_schema or _function_schema.function_schema( function, schema_generator, takes_ctx=takes_ctx, docstring_format=docstring_format, require_parameter_descriptions=require_parameter_descriptions, ) self.takes_ctx = self.function_schema.takes_ctx self.max_retries = max_retries self.name = name or function.__name__ self.description = description or self.function_schema.description self.prepare = prepare self.docstring_format = docstring_format self.require_parameter_descriptions = require_parameter_descriptions self.strict = strict ```` #### function_schema ```python function_schema: FunctionSchema = ( function_schema or function_schema( function, schema_generator, takes_ctx=takes_ctx, docstring_format=docstring_format, require_parameter_descriptions=require_parameter_descriptions, ) ) ``` The base JSON schema for the tool's parameters. This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request. #### from_schema ```python from_schema( function: Callable[..., Any], name: str, description: str, json_schema: JsonSchemaValue, ) -> Self ``` Creates a Pydantic tool from a function and a JSON schema. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `function` | `Callable[..., Any]` | The function to call. This will be called with keywords only, and no validation of the arguments will be performed. | *required* | | `name` | `str` | The unique name of the tool that clearly communicates its purpose | *required* | | `description` | `str` | Used to tell the model how/when/why to use the tool. You can provide few-shot examples as a part of the description. | *required* | | `json_schema` | `JsonSchemaValue` | The schema for the function arguments | *required* | Returns: | Type | Description | | --- | --- | | `Self` | A Pydantic tool that calls the function | Source code in `pydantic_ai_slim/pydantic_ai/tools.py` ```python @classmethod def from_schema( cls, function: Callable[..., Any], name: str, description: str, json_schema: JsonSchemaValue, ) -> Self: """Creates a Pydantic tool from a function and a JSON schema. Args: function: The function to call. This will be called with keywords only, and no validation of the arguments will be performed. name: The unique name of the tool that clearly communicates its purpose description: Used to tell the model how/when/why to use the tool. You can provide few-shot examples as a part of the description. json_schema: The schema for the function arguments Returns: A Pydantic tool that calls the function """ function_schema = _function_schema.FunctionSchema( function=function, description=description, validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, is_async=_utils.is_async_callable(function), ) return cls( function, takes_ctx=False, name=name, description=description, function_schema=function_schema, ) ``` #### prepare_tool_def ```python prepare_tool_def( ctx: RunContext[AgentDepsT], ) -> ToolDefinition | None ``` Get the tool definition. By default, this method creates a tool definition, then either returns it, or calls `self.prepare` if it's set. Returns: | Type | Description | | --- | --- | | `ToolDefinition | None` | return a ToolDefinition or None if the tools should not be registered for this run. 
| Source code in `pydantic_ai_slim/pydantic_ai/tools.py` ```python async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. By default, this method creates a tool definition, then either returns it, or calls `self.prepare` if it's set. Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ tool_def = ToolDefinition( name=self.name, description=self.description, parameters_json_schema=self.function_schema.json_schema, strict=self.strict, ) if self.prepare is not None: return await self.prepare(ctx, tool_def) else: return tool_def ``` #### run ```python run( message: ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer, include_content: bool = False, ) -> ToolReturnPart | RetryPromptPart ``` Run the tool function asynchronously. This method wraps `_run` in an OpenTelemetry span. See . Source code in `pydantic_ai_slim/pydantic_ai/tools.py` ```python async def run( self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer, include_content: bool = False, ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: """Run the tool function asynchronously. This method wraps `_run` in an OpenTelemetry span. See . """ span_attributes = { 'gen_ai.tool.name': self.name, # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai 'gen_ai.tool.call.id': message.tool_call_id, **({'tool_arguments': message.args_as_json_str()} if include_content else {}), 'logfire.msg': f'running tool: {self.name}', # add the JSON schema so these attributes are formatted nicely in Logfire 'logfire.json_schema': json.dumps( { 'type': 'object', 'properties': { **( { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, } if include_content else {} ), 'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}, }, } ), } with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: response = await self._run(message, run_context) if include_content and span.is_recording(): span.set_attribute( 'tool_response', response.model_response_str() if isinstance(response, ToolReturnPart) else response.model_response(), ) return response ``` ### ObjectJsonSchema ```python ObjectJsonSchema: TypeAlias = dict[str, Any] ``` Type representing JSON schema of an object, e.g. where `"type": "object"`. This type is used to define tools parameters (aka arguments) in ToolDefinition. With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any` ### ToolDefinition Definition of a tool passed to a model. This is used for both function tools and output tools. Source code in `pydantic_ai_slim/pydantic_ai/tools.py` ```python @dataclass(repr=False) class ToolDefinition: """Definition of a tool passed to a model. This is used for both function tools and output tools. """ name: str """The name of the tool.""" description: str """The description of the tool.""" parameters_json_schema: ObjectJsonSchema """The JSON schema for the tool's parameters.""" outer_typed_dict_key: str | None = None """The key in the outer [TypedDict] that wraps an output tool. This will only be set for output tools which don't have an `object` JSON schema. """ strict: bool | None = None """Whether to enforce (vendor-specific) strict JSON schema validation for tool calls. 
Setting this to `True` while using a supported model generally imposes some restrictions on the tool's JSON schema in exchange for guaranteeing the API responses strictly match that schema. When `False`, the model may be free to generate other properties or types (depending on the vendor). When `None` (the default), the value will be inferred based on the compatibility of the parameters_json_schema. Note: this is currently only supported by OpenAI models. """ __repr__ = _utils.dataclasses_no_defaults_repr ``` #### name ```python name: str ``` The name of the tool. #### description ```python description: str ``` The description of the tool. #### parameters_json_schema ```python parameters_json_schema: ObjectJsonSchema ``` The JSON schema for the tool's parameters. #### outer_typed_dict_key ```python outer_typed_dict_key: str | None = None ``` The key in the outer [TypedDict] that wraps an output tool. This will only be set for output tools which don't have an `object` JSON schema. #### strict ```python strict: bool | None = None ``` Whether to enforce (vendor-specific) strict JSON schema validation for tool calls. Setting this to `True` while using a supported model generally imposes some restrictions on the tool's JSON schema in exchange for guaranteeing the API responses strictly match that schema. When `False`, the model may be free to generate other properties or types (depending on the vendor). When `None` (the default), the value will be inferred based on the compatibility of the parameters_json_schema. Note: this is currently only supported by OpenAI models. # `pydantic_ai.usage` ### Usage LLM usage associated with a request or run. Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests. You'll need to look up the documentation of the model you're using to convert usage to monetary costs. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python @dataclass(repr=False) class Usage: """LLM usage associated with a request or run. Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests. You'll need to look up the documentation of the model you're using to convert usage to monetary costs. """ requests: int = 0 """Number of requests made to the LLM API.""" request_tokens: int | None = None """Tokens used in processing requests.""" response_tokens: int | None = None """Tokens used in generating responses.""" total_tokens: int | None = None """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" details: dict[str, int] | None = None """Any extra details returned by the model.""" def incr(self, incr_usage: Usage) -> None: """Increment the usage in place. Args: incr_usage: The usage to increment by. """ for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': self_value = getattr(self, f) other_value = getattr(incr_usage, f) if self_value is not None or other_value is not None: setattr(self, f, (self_value or 0) + (other_value or 0)) if incr_usage.details: self.details = self.details or {} for key, value in incr_usage.details.items(): self.details[key] = self.details.get(key, 0) + value def __add__(self, other: Usage) -> Usage: """Add two Usages together. This is provided so it's trivial to sum usage information from multiple requests and runs. 
""" new_usage = copy(self) new_usage.incr(other) return new_usage def opentelemetry_attributes(self) -> dict[str, int]: """Get the token limits as OpenTelemetry attributes.""" result = { 'gen_ai.usage.input_tokens': self.request_tokens, 'gen_ai.usage.output_tokens': self.response_tokens, } for key, value in (self.details or {}).items(): result[f'gen_ai.usage.details.{key}'] = value # pragma: no cover return {k: v for k, v in result.items() if v} def has_values(self) -> bool: """Whether any values are set and non-zero.""" return bool(self.requests or self.request_tokens or self.response_tokens or self.details) __repr__ = _utils.dataclasses_no_defaults_repr ``` #### requests ```python requests: int = 0 ``` Number of requests made to the LLM API. #### request_tokens ```python request_tokens: int | None = None ``` Tokens used in processing requests. #### response_tokens ```python response_tokens: int | None = None ``` Tokens used in generating responses. #### total_tokens ```python total_tokens: int | None = None ``` Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`. #### details ```python details: dict[str, int] | None = None ``` Any extra details returned by the model. #### incr ```python incr(incr_usage: Usage) -> None ``` Increment the usage in place. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `incr_usage` | `Usage` | The usage to increment by. | *required* | Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def incr(self, incr_usage: Usage) -> None: """Increment the usage in place. Args: incr_usage: The usage to increment by. """ for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': self_value = getattr(self, f) other_value = getattr(incr_usage, f) if self_value is not None or other_value is not None: setattr(self, f, (self_value or 0) + (other_value or 0)) if incr_usage.details: self.details = self.details or {} for key, value in incr_usage.details.items(): self.details[key] = self.details.get(key, 0) + value ``` #### __add__ ```python __add__(other: Usage) -> Usage ``` Add two Usages together. This is provided so it's trivial to sum usage information from multiple requests and runs. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def __add__(self, other: Usage) -> Usage: """Add two Usages together. This is provided so it's trivial to sum usage information from multiple requests and runs. """ new_usage = copy(self) new_usage.incr(other) return new_usage ``` #### opentelemetry_attributes ```python opentelemetry_attributes() -> dict[str, int] ``` Get the token limits as OpenTelemetry attributes. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def opentelemetry_attributes(self) -> dict[str, int]: """Get the token limits as OpenTelemetry attributes.""" result = { 'gen_ai.usage.input_tokens': self.request_tokens, 'gen_ai.usage.output_tokens': self.response_tokens, } for key, value in (self.details or {}).items(): result[f'gen_ai.usage.details.{key}'] = value # pragma: no cover return {k: v for k, v in result.items() if v} ``` #### has_values ```python has_values() -> bool ``` Whether any values are set and non-zero. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def has_values(self) -> bool: """Whether any values are set and non-zero.""" return bool(self.requests or self.request_tokens or self.response_tokens or self.details) ``` ### UsageLimits Limits on model usage. 
The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model. Token counts are provided in responses from the model, and the token limits are checked after each response. Each of the limits can be set to `None` to disable that limit. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python @dataclass(repr=False) class UsageLimits: """Limits on model usage. The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model. Token counts are provided in responses from the model, and the token limits are checked after each response. Each of the limits can be set to `None` to disable that limit. """ request_limit: int | None = 50 """The maximum number of requests allowed to the model.""" request_tokens_limit: int | None = None """The maximum number of tokens allowed in requests to the model.""" response_tokens_limit: int | None = None """The maximum number of tokens allowed in responses from the model.""" total_tokens_limit: int | None = None """The maximum number of tokens allowed in requests and responses combined.""" def has_token_limits(self) -> bool: """Returns `True` if this instance places any limits on token counts. If this returns `False`, the `check_tokens` method will never raise an error. This is useful because if we have token limits, we need to check them after receiving each streamed message. If there are no limits, we can skip that processing in the streaming response iterator. """ return any( limit is not None for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) ) def check_before_request(self, usage: Usage) -> None: """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit.""" request_limit = self.request_limit if request_limit is not None and usage.requests >= request_limit: raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') def check_tokens(self, usage: Usage) -> None: """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" request_tokens = usage.request_tokens or 0 if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: raise UsageLimitExceeded( f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' ) response_tokens = usage.response_tokens or 0 if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: raise UsageLimitExceeded( f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' ) total_tokens = usage.total_tokens or 0 if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') __repr__ = _utils.dataclasses_no_defaults_repr ``` #### request_limit ```python request_limit: int | None = 50 ``` The maximum number of requests allowed to the model. #### request_tokens_limit ```python request_tokens_limit: int | None = None ``` The maximum number of tokens allowed in requests to the model. #### response_tokens_limit ```python response_tokens_limit: int | None = None ``` The maximum number of tokens allowed in responses from the model. #### total_tokens_limit ```python total_tokens_limit: int | None = None ``` The maximum number of tokens allowed in requests and responses combined. 
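For example (a minimal sketch — the agent, prompt, and limit values are illustrative), limits are passed per run via the `usage_limits` argument, and `UsageLimitExceeded` is raised as soon as one is exceeded:

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import UsageLimits

agent = Agent('openai:gpt-4o')

try:
    result = agent.run_sync(
        'Summarise the history of Rome in detail.',
        usage_limits=UsageLimits(request_limit=5, response_tokens_limit=200),
    )
    # the aggregated Usage for the whole run
    print(result.usage())
except UsageLimitExceeded as e:
    print(e)
```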
#### has_token_limits ```python has_token_limits() -> bool ``` Returns `True` if this instance places any limits on token counts. If this returns `False`, the `check_tokens` method will never raise an error. This is useful because if we have token limits, we need to check them after receiving each streamed message. If there are no limits, we can skip that processing in the streaming response iterator. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def has_token_limits(self) -> bool: """Returns `True` if this instance places any limits on token counts. If this returns `False`, the `check_tokens` method will never raise an error. This is useful because if we have token limits, we need to check them after receiving each streamed message. If there are no limits, we can skip that processing in the streaming response iterator. """ return any( limit is not None for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) ) ``` #### check_before_request ```python check_before_request(usage: Usage) -> None ``` Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def check_before_request(self, usage: Usage) -> None: """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit.""" request_limit = self.request_limit if request_limit is not None and usage.requests >= request_limit: raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') ``` #### check_tokens ```python check_tokens(usage: Usage) -> None ``` Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits. Source code in `pydantic_ai_slim/pydantic_ai/usage.py` ```python def check_tokens(self, usage: Usage) -> None: """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" request_tokens = usage.request_tokens or 0 if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: raise UsageLimitExceeded( f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' ) response_tokens = usage.response_tokens or 0 if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: raise UsageLimitExceeded( f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' ) total_tokens = usage.total_tokens or 0 if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') ``` # `pydantic_ai.models.anthropic` ## Setup For details on how to set up authentication with this model, see [model configuration for Anthropic](../../../models/anthropic/). ### LatestAnthropicModelNames ```python LatestAnthropicModelNames = ModelParam ``` Latest Anthropic models. ### AnthropicModelName ```python AnthropicModelName = Union[str, LatestAnthropicModelNames] ``` Possible Anthropic model names. Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints. See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list. ### AnthropicModelSettings Bases: `ModelSettings` Settings used for an Anthropic model request. 
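For example (a minimal sketch, assuming `ANTHROPIC_API_KEY` is set in the environment; the model name and token budgets are illustrative), Anthropic-specific settings are passed alongside the common `ModelSettings` fields:

```python
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

model = AnthropicModel('claude-3-7-sonnet-latest')
agent = Agent(
    model,
    model_settings=AnthropicModelSettings(
        max_tokens=2048,
        # anthropic_-prefixed fields only apply to Anthropic models
        anthropic_thinking={'type': 'enabled', 'budget_tokens': 1024},
    ),
)
```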
Source code in `pydantic_ai_slim/pydantic_ai/models/anthropic.py` ```python class AnthropicModelSettings(ModelSettings, total=False): """Settings used for an Anthropic model request.""" # ALL FIELDS MUST BE `anthropic_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. anthropic_metadata: BetaMetadataParam """An object describing metadata about the request. Contains `user_id`, an external identifier for the user who is associated with the request. """ anthropic_thinking: BetaThinkingConfigParam """Determine whether the model should generate a thinking block. See [the Anthropic docs](https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking) for more information. """ ``` #### anthropic_metadata ```python anthropic_metadata: BetaMetadataParam ``` An object describing metadata about the request. Contains `user_id`, an external identifier for the user who is associated with the request. #### anthropic_thinking ```python anthropic_thinking: BetaThinkingConfigParam ``` Determine whether the model should generate a thinking block. See [the Anthropic docs](https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking) for more information. ### AnthropicModel Bases: `Model` A model that uses the Anthropic API. Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/anthropic.py` ```python @dataclass(init=False) class AnthropicModel(Model): """A model that uses the Anthropic API. Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. """ client: AsyncAnthropic = field(repr=False) _model_name: AnthropicModelName = field(repr=False) _system: str = field(default='anthropic', repr=False) def __init__( self, model_name: AnthropicModelName, *, provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic', profile: ModelProfileSpec | None = None, ): """Initialize an Anthropic model. Args: model_name: The name of the Anthropic model to use. List of model names available [here](https://docs.anthropic.com/en/docs/about-claude/models). provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. 
""" self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile @property def base_url(self) -> str: return str(self.client.base_url) async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() response = await self._messages_create( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) model_response.usage.requests = 1 return model_response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._messages_create( messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) async with response: yield await self._process_streamed_response(response) @property def model_name(self) -> AnthropicModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system @overload async def _messages_create( self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings, model_request_parameters: ModelRequestParameters, ) -> AsyncStream[BetaRawMessageStreamEvent]: pass @overload async def _messages_create( self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings, model_request_parameters: ModelRequestParameters, ) -> BetaMessage: pass async def _messages_create( self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings, model_request_parameters: ModelRequestParameters, ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]: # standalone function to make it easier to override tools = self._get_tools(model_request_parameters) tool_choice: BetaToolChoiceParam | None if not tools: tool_choice = None else: if not model_request_parameters.allow_text_output: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None: tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls system_prompt, anthropic_messages = await self._map_message(messages) try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) return await self.client.beta.messages.create( max_tokens=model_settings.get('max_tokens', 4096), system=system_prompt or NOT_GIVEN, messages=anthropic_messages, model=self._model_name, tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stream=stream, thinking=model_settings.get('anthropic_thinking', NOT_GIVEN), stop_sequences=model_settings.get('stop_sequences', NOT_GIVEN), temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), metadata=model_settings.get('anthropic_metadata', NOT_GIVEN), extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover def _process_response(self, 
response: BetaMessage) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" items: list[ModelResponsePart] = [] for item in response.content: if isinstance(item, BetaTextBlock): items.append(TextPart(content=item.text)) elif isinstance(item, BetaRedactedThinkingBlock): # pragma: no cover warnings.warn( 'PydanticAI currently does not handle redacted thinking blocks. ' 'If you have a suggestion on how we should handle them, please open an issue.', UserWarning, ) elif isinstance(item, BetaThinkingBlock): items.append(ThinkingPart(content=item.thinking, signature=item.signature)) else: assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}' items.append( ToolCallPart( tool_name=item.name, args=cast(dict[str, Any], item.input), tool_call_id=item.id, ) ) return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id) async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse: peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') # pragma: no cover # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time timestamp = datetime.now(tz=timezone.utc) return AnthropicStreamedResponse( _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]: tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] return tools async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901 """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`.""" system_prompt_parts: list[str] = [] anthropic_messages: list[BetaMessageParam] = [] for m in messages: if isinstance(m, ModelRequest): user_content_params: list[BetaContentBlockParam] = [] for request_part in m.parts: if isinstance(request_part, SystemPromptPart): system_prompt_parts.append(request_part.content) elif isinstance(request_part, UserPromptPart): async for content in self._map_user_prompt(request_part): user_content_params.append(content) elif isinstance(request_part, ToolReturnPart): tool_result_block_param = BetaToolResultBlockParam( tool_use_id=_guard_tool_call_id(t=request_part), type='tool_result', content=request_part.model_response_str(), is_error=False, ) user_content_params.append(tool_result_block_param) elif isinstance(request_part, RetryPromptPart): # pragma: no branch if request_part.tool_name is None: text = request_part.model_response() # pragma: no cover retry_param = BetaTextBlockParam(type='text', text=text) # pragma: no cover else: retry_param = BetaToolResultBlockParam( tool_use_id=_guard_tool_call_id(t=request_part), type='tool_result', content=request_part.model_response(), is_error=True, ) user_content_params.append(retry_param) if len(user_content_params) > 0: anthropic_messages.append(BetaMessageParam(role='user', content=user_content_params)) elif isinstance(m, ModelResponse): assistant_content_params: list[BetaTextBlockParam | BetaToolUseBlockParam | BetaThinkingBlockParam] = [] for response_part in m.parts: if isinstance(response_part, 
TextPart): if response_part.content: # Only add non-empty text assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text')) elif isinstance(response_part, ThinkingPart): # NOTE: We only send thinking part back for Anthropic, otherwise they raise an error. if response_part.signature is not None: # pragma: no branch assistant_content_params.append( BetaThinkingBlockParam( thinking=response_part.content, signature=response_part.signature, type='thinking' ) ) else: tool_use_block_param = BetaToolUseBlockParam( id=_guard_tool_call_id(t=response_part), type='tool_use', name=response_part.tool_name, input=response_part.args_as_dict(), ) assistant_content_params.append(tool_use_block_param) if len(assistant_content_params) > 0: anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params)) else: assert_never(m) system_prompt = '\n\n'.join(system_prompt_parts) if instructions := self._get_instructions(messages): system_prompt = f'{instructions}\n\n{system_prompt}' return system_prompt, anthropic_messages @staticmethod async def _map_user_prompt( part: UserPromptPart, ) -> AsyncGenerator[BetaContentBlockParam]: if isinstance(part.content, str): if part.content: # Only yield non-empty text yield BetaTextBlockParam(text=part.content, type='text') else: for item in part.content: if isinstance(item, str): if item: # Only yield non-empty text yield BetaTextBlockParam(text=item, type='text') elif isinstance(item, BinaryContent): if item.is_image: yield BetaImageBlockParam( source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'}, # type: ignore type='image', ) elif item.media_type == 'application/pdf': yield BetaBase64PDFBlockParam( source=BetaBase64PDFSourceParam( data=io.BytesIO(item.data), media_type='application/pdf', type='base64', ), type='document', ) else: raise RuntimeError('Only images and PDFs are supported for binary content') elif isinstance(item, ImageUrl): yield BetaImageBlockParam(source={'type': 'url', 'url': item.url}, type='image') elif isinstance(item, DocumentUrl): if item.media_type == 'application/pdf': yield BetaBase64PDFBlockParam(source={'url': item.url, 'type': 'url'}, type='document') elif item.media_type == 'text/plain': downloaded_item = await download_item(item, data_format='text') yield BetaBase64PDFBlockParam( source=BetaPlainTextSourceParam( data=downloaded_item['data'], media_type=item.media_type, type='text' ), type='document', ) else: # pragma: no cover raise RuntimeError(f'Unsupported media type: {item.media_type}') else: raise RuntimeError(f'Unsupported content type: {type(item)}') # pragma: no cover @staticmethod def _map_tool_definition(f: ToolDefinition) -> BetaToolParam: return { 'name': f.name, 'description': f.description, 'input_schema': f.parameters_json_schema, } ``` #### __init__ ```python __init__( model_name: AnthropicModelName, *, provider: ( Literal["anthropic"] | Provider[AsyncAnthropic] ) = "anthropic", profile: ModelProfileSpec | None = None ) ``` Initialize an Anthropic model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `AnthropicModelName` | The name of the Anthropic model to use. List of model names available here. | *required* | | `provider` | `Literal['anthropic'] | Provider[AsyncAnthropic]` | The provider to use for the Anthropic API. Can be either the string 'anthropic' or an instance of Provider[AsyncAnthropic]. If not provided, the other parameters will be used. 
| `'anthropic'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/anthropic.py` ```python def __init__( self, model_name: AnthropicModelName, *, provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic', profile: ModelProfileSpec | None = None, ): """Initialize an Anthropic model. Args: model_name: The name of the Anthropic model to use. List of model names available [here](https://docs.anthropic.com/en/docs/about-claude/models). provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile ``` #### model_name ```python model_name: AnthropicModelName ``` The model name. #### system ```python system: str ``` The system / model provider. ### AnthropicStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for Anthropic models. Source code in `pydantic_ai_slim/pydantic_ai/models/anthropic.py` ```python @dataclass class AnthropicStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Anthropic models.""" _model_name: AnthropicModelName _response: AsyncIterable[BetaRawMessageStreamEvent] _timestamp: datetime async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: current_block: BetaContentBlock | None = None async for event in self._response: self._usage += _map_usage(event) if isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block if isinstance(current_block, BetaTextBlock) and current_block.text: yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text) elif isinstance(current_block, BetaThinkingBlock): yield self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', content=current_block.thinking, signature=current_block.signature, ) elif isinstance(current_block, BetaToolUseBlock): maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=current_block.id, tool_name=current_block.name, args=cast(dict[str, Any], current_block.input) or None, tool_call_id=current_block.id, ) if maybe_event is not None: # pragma: no branch yield maybe_event elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text) elif isinstance(event.delta, BetaThinkingDelta): yield self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', content=event.delta.thinking ) elif isinstance(event.delta, BetaSignatureDelta): yield self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', signature=event.delta.signature ) elif ( current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, BetaToolUseBlock) ): # pragma: no branch maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=current_block.id, tool_name='', args=event.delta.partial_json, tool_call_id=current_block.id, ) if maybe_event is not None: # pragma: no branch yield maybe_event elif isinstance(event, (BetaRawContentBlockStopEvent, 
BetaRawMessageStopEvent)): current_block = None @property def model_name(self) -> AnthropicModelName: """Get the model name of the response.""" return self._model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._timestamp ``` #### model_name ```python model_name: AnthropicModelName ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. # `pydantic_ai.models` Logic related to making requests to an LLM. The aim here is to make a common interface for different LLMs, so that the rest of the code can be agnostic to the specific LLM being used. ### KnownModelName ```python KnownModelName = TypeAliasType( "KnownModelName", Literal[ "anthropic:claude-2.0", "anthropic:claude-2.1", "anthropic:claude-3-5-haiku-20241022", "anthropic:claude-3-5-haiku-latest", "anthropic:claude-3-5-sonnet-20240620", "anthropic:claude-3-5-sonnet-20241022", "anthropic:claude-3-5-sonnet-latest", "anthropic:claude-3-7-sonnet-20250219", "anthropic:claude-3-7-sonnet-latest", "anthropic:claude-3-haiku-20240307", "anthropic:claude-3-opus-20240229", "anthropic:claude-3-opus-latest", "anthropic:claude-3-sonnet-20240229", "anthropic:claude-4-opus-20250514", "anthropic:claude-4-sonnet-20250514", "anthropic:claude-opus-4-0", "anthropic:claude-opus-4-20250514", "anthropic:claude-sonnet-4-0", "anthropic:claude-sonnet-4-20250514", "bedrock:amazon.titan-tg1-large", "bedrock:amazon.titan-text-lite-v1", "bedrock:amazon.titan-text-express-v1", "bedrock:us.amazon.nova-pro-v1:0", "bedrock:us.amazon.nova-lite-v1:0", "bedrock:us.amazon.nova-micro-v1:0", "bedrock:anthropic.claude-3-5-sonnet-20241022-v2:0", "bedrock:us.anthropic.claude-3-5-sonnet-20241022-v2:0", "bedrock:anthropic.claude-3-5-haiku-20241022-v1:0", "bedrock:us.anthropic.claude-3-5-haiku-20241022-v1:0", "bedrock:anthropic.claude-instant-v1", "bedrock:anthropic.claude-v2:1", "bedrock:anthropic.claude-v2", "bedrock:anthropic.claude-3-sonnet-20240229-v1:0", "bedrock:us.anthropic.claude-3-sonnet-20240229-v1:0", "bedrock:anthropic.claude-3-haiku-20240307-v1:0", "bedrock:us.anthropic.claude-3-haiku-20240307-v1:0", "bedrock:anthropic.claude-3-opus-20240229-v1:0", "bedrock:us.anthropic.claude-3-opus-20240229-v1:0", "bedrock:anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock:us.anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock:anthropic.claude-3-7-sonnet-20250219-v1:0", "bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0", "bedrock:anthropic.claude-opus-4-20250514-v1:0", "bedrock:us.anthropic.claude-opus-4-20250514-v1:0", "bedrock:anthropic.claude-sonnet-4-20250514-v1:0", "bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0", "bedrock:cohere.command-text-v14", "bedrock:cohere.command-r-v1:0", "bedrock:cohere.command-r-plus-v1:0", "bedrock:cohere.command-light-text-v14", "bedrock:meta.llama3-8b-instruct-v1:0", "bedrock:meta.llama3-70b-instruct-v1:0", "bedrock:meta.llama3-1-8b-instruct-v1:0", "bedrock:us.meta.llama3-1-8b-instruct-v1:0", "bedrock:meta.llama3-1-70b-instruct-v1:0", "bedrock:us.meta.llama3-1-70b-instruct-v1:0", "bedrock:meta.llama3-1-405b-instruct-v1:0", "bedrock:us.meta.llama3-2-11b-instruct-v1:0", "bedrock:us.meta.llama3-2-90b-instruct-v1:0", "bedrock:us.meta.llama3-2-1b-instruct-v1:0", "bedrock:us.meta.llama3-2-3b-instruct-v1:0", "bedrock:us.meta.llama3-3-70b-instruct-v1:0", "bedrock:mistral.mistral-7b-instruct-v0:2", "bedrock:mistral.mixtral-8x7b-instruct-v0:1", "bedrock:mistral.mistral-large-2402-v1:0", "bedrock:mistral.mistral-large-2407-v1:0", 
"claude-2.0", "claude-2.1", "claude-3-5-haiku-20241022", "claude-3-5-haiku-latest", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-latest", "claude-3-7-sonnet-20250219", "claude-3-7-sonnet-latest", "claude-3-haiku-20240307", "claude-3-opus-20240229", "claude-3-opus-latest", "claude-3-sonnet-20240229", "claude-4-opus-20250514", "claude-4-sonnet-20250514", "claude-opus-4-0", "claude-opus-4-20250514", "claude-sonnet-4-0", "claude-sonnet-4-20250514", "cohere:c4ai-aya-expanse-32b", "cohere:c4ai-aya-expanse-8b", "cohere:command", "cohere:command-light", "cohere:command-light-nightly", "cohere:command-nightly", "cohere:command-r", "cohere:command-r-03-2024", "cohere:command-r-08-2024", "cohere:command-r-plus", "cohere:command-r-plus-04-2024", "cohere:command-r-plus-08-2024", "cohere:command-r7b-12-2024", "deepseek:deepseek-chat", "deepseek:deepseek-reasoner", "google-gla:gemini-1.5-flash", "google-gla:gemini-1.5-flash-8b", "google-gla:gemini-1.5-pro", "google-gla:gemini-1.0-pro", "google-gla:gemini-2.0-flash", "google-gla:gemini-2.0-flash-lite-preview-02-05", "google-gla:gemini-2.0-pro-exp-02-05", "google-gla:gemini-2.5-flash-preview-05-20", "google-gla:gemini-2.5-flash", "google-gla:gemini-2.5-flash-lite-preview-06-17", "google-gla:gemini-2.5-pro-exp-03-25", "google-gla:gemini-2.5-pro-preview-05-06", "google-gla:gemini-2.5-pro", "google-vertex:gemini-1.5-flash", "google-vertex:gemini-1.5-flash-8b", "google-vertex:gemini-1.5-pro", "google-vertex:gemini-1.0-pro", "google-vertex:gemini-2.0-flash", "google-vertex:gemini-2.0-flash-lite-preview-02-05", "google-vertex:gemini-2.0-pro-exp-02-05", "google-vertex:gemini-2.5-flash-preview-05-20", "google-vertex:gemini-2.5-flash", "google-vertex:gemini-2.5-flash-lite-preview-06-17", "google-vertex:gemini-2.5-pro-exp-03-25", "google-vertex:gemini-2.5-pro-preview-05-06", "google-vertex:gemini-2.5-pro", "gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-4", "gpt-4-0125-preview", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-vision-preview", "gpt-4.1", "gpt-4.1-2025-04-14", "gpt-4.1-mini", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano", "gpt-4.1-nano-2025-04-14", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4o-mini-audio-preview", "gpt-4o-mini-audio-preview-2024-12-17", "gpt-4o-mini-search-preview", "gpt-4o-mini-search-preview-2025-03-11", "gpt-4o-search-preview", "gpt-4o-search-preview-2025-03-11", "groq:distil-whisper-large-v3-en", "groq:gemma2-9b-it", "groq:llama-3.3-70b-versatile", "groq:llama-3.1-8b-instant", "groq:llama-guard-3-8b", "groq:llama3-70b-8192", "groq:llama3-8b-8192", "groq:whisper-large-v3", "groq:whisper-large-v3-turbo", "groq:playai-tts", "groq:playai-tts-arabic", "groq:qwen-qwq-32b", "groq:mistral-saba-24b", "groq:qwen-2.5-coder-32b", "groq:qwen-2.5-32b", "groq:deepseek-r1-distill-qwen-32b", "groq:deepseek-r1-distill-llama-70b", "groq:llama-3.3-70b-specdec", "groq:llama-3.2-1b-preview", "groq:llama-3.2-3b-preview", "groq:llama-3.2-11b-vision-preview", "groq:llama-3.2-90b-vision-preview", "heroku:claude-3-5-haiku", "heroku:claude-3-5-sonnet-latest", "heroku:claude-3-7-sonnet", "heroku:claude-4-sonnet", 
"heroku:claude-3-haiku", "mistral:codestral-latest", "mistral:mistral-large-latest", "mistral:mistral-moderation-latest", "mistral:mistral-small-latest", "o1", "o1-2024-12-17", "o1-mini", "o1-mini-2024-09-12", "o1-preview", "o1-preview-2024-09-12", "o3", "o3-2025-04-16", "o3-mini", "o3-mini-2025-01-31", "openai:chatgpt-4o-latest", "openai:gpt-3.5-turbo", "openai:gpt-3.5-turbo-0125", "openai:gpt-3.5-turbo-0301", "openai:gpt-3.5-turbo-0613", "openai:gpt-3.5-turbo-1106", "openai:gpt-3.5-turbo-16k", "openai:gpt-3.5-turbo-16k-0613", "openai:gpt-4", "openai:gpt-4-0125-preview", "openai:gpt-4-0314", "openai:gpt-4-0613", "openai:gpt-4-1106-preview", "openai:gpt-4-32k", "openai:gpt-4-32k-0314", "openai:gpt-4-32k-0613", "openai:gpt-4-turbo", "openai:gpt-4-turbo-2024-04-09", "openai:gpt-4-turbo-preview", "openai:gpt-4-vision-preview", "openai:gpt-4.1", "openai:gpt-4.1-2025-04-14", "openai:gpt-4.1-mini", "openai:gpt-4.1-mini-2025-04-14", "openai:gpt-4.1-nano", "openai:gpt-4.1-nano-2025-04-14", "openai:gpt-4o", "openai:gpt-4o-2024-05-13", "openai:gpt-4o-2024-08-06", "openai:gpt-4o-2024-11-20", "openai:gpt-4o-audio-preview", "openai:gpt-4o-audio-preview-2024-10-01", "openai:gpt-4o-audio-preview-2024-12-17", "openai:gpt-4o-mini", "openai:gpt-4o-mini-2024-07-18", "openai:gpt-4o-mini-audio-preview", "openai:gpt-4o-mini-audio-preview-2024-12-17", "openai:gpt-4o-mini-search-preview", "openai:gpt-4o-mini-search-preview-2025-03-11", "openai:gpt-4o-search-preview", "openai:gpt-4o-search-preview-2025-03-11", "openai:o1", "openai:o1-2024-12-17", "openai:o1-mini", "openai:o1-mini-2024-09-12", "openai:o1-preview", "openai:o1-preview-2024-09-12", "openai:o3", "openai:o3-2025-04-16", "openai:o3-mini", "openai:o3-mini-2025-01-31", "openai:o4-mini", "openai:o4-mini-2025-04-16", "test", ], ) ``` Known model names that can be used with the `model` parameter of Agent. `KnownModelName` is provided as a concise way to specify a model. ### ModelRequestParameters Configuration for an agent's request to a model, specifically related to tools and output handling. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python @dataclass(repr=False) class ModelRequestParameters: """Configuration for an agent's request to a model, specifically related to tools and output handling.""" function_tools: list[ToolDefinition] = field(default_factory=list) output_mode: OutputMode = 'text' output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) allow_text_output: bool = True __repr__ = _utils.dataclasses_no_defaults_repr ``` ### Model Bases: `ABC` Abstract class for a model. 
Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python class Model(ABC): """Abstract class for a model.""" _profile: ModelProfileSpec | None = None @abstractmethod async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Make a request to the model.""" raise NotImplementedError() @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" # This method is not required, but you need to implement it if you want to support streamed responses raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}') # yield is required to make this a generator for type checking # noinspection PyUnreachableCode yield # pragma: no cover def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: """Customize the request parameters for the model. This method can be overridden by subclasses to modify the request parameters before sending them to the model. In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary for vendor/model-specific reasons. """ if transformer := self.profile.json_schema_transformer: model_request_parameters = replace( model_request_parameters, function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools], output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools], ) if output_object := model_request_parameters.output_object: model_request_parameters = replace( model_request_parameters, output_object=_customize_output_object(transformer, output_object), ) return model_request_parameters @property @abstractmethod def model_name(self) -> str: """The model name.""" raise NotImplementedError() @cached_property def profile(self) -> ModelProfile: """The model profile.""" _profile = self._profile if callable(_profile): _profile = _profile(self.model_name) if _profile is None: return DEFAULT_PROFILE return _profile @property @abstractmethod def system(self) -> str: """The system / model provider, ex: openai. Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute, so should use well-known values listed in https://opentelemetry.io/docs/specs/semconv/attributes-registry/gen-ai/#gen-ai-system when applicable. """ raise NotImplementedError() @property def base_url(self) -> str | None: """The base URL for the provider API, if available.""" return None @staticmethod def _get_instructions(messages: list[ModelMessage]) -> str | None: """Get instructions from the first ModelRequest found when iterating messages in reverse. In the case that a "mock" request was generated to include a tool-return part for a result tool, we want to use the instructions from the second-to-most-recent request (which should correspond to the original request that generated the response that resulted in the tool-return part). 
""" last_two_requests: list[ModelRequest] = [] for message in reversed(messages): if isinstance(message, ModelRequest): last_two_requests.append(message) if len(last_two_requests) == 2: break if message.instructions is not None: return message.instructions # If we don't have two requests, and we didn't already return instructions, there are definitely not any: if len(last_two_requests) != 2: return None most_recent_request = last_two_requests[0] second_most_recent_request = last_two_requests[1] # If we've gotten this far and the most recent request consists of only tool-return parts or retry-prompt parts, # we use the instructions from the second-to-most-recent request. This is necessary because when handling # result tools, we generate a "mock" ModelRequest with a tool-return part for it, and that ModelRequest will not # have the relevant instructions from the agent. # While it's possible that you could have a message history where the most recent request has only tool returns, # I believe there is no way to achieve that would _change_ the instructions without manually crafting the most # recent message. That might make sense in principle for some usage pattern, but it's enough of an edge case # that I think it's not worth worrying about, since you can work around this by inserting another ModelRequest # with no parts at all immediately before the request that has the tool calls (that works because we only look # at the two most recent ModelRequests here). # If you have a use case where this causes pain, please open a GitHub issue and we can discuss alternatives. if all(p.part_kind == 'tool-return' or p.part_kind == 'retry-prompt' for p in most_recent_request.parts): return second_most_recent_request.instructions return None ``` #### request ```python request( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse ``` Make a request to the model. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python @abstractmethod async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Make a request to the model.""" raise NotImplementedError() ``` #### request_stream ```python request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse] ``` Make a request to the model and return a streaming response. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" # This method is not required, but you need to implement it if you want to support streamed responses raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}') # yield is required to make this a generator for type checking # noinspection PyUnreachableCode yield # pragma: no cover ``` #### customize_request_parameters ```python customize_request_parameters( model_request_parameters: ModelRequestParameters, ) -> ModelRequestParameters ``` Customize the request parameters for the model. This method can be overridden by subclasses to modify the request parameters before sending them to the model. 
In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary for vendor/model-specific reasons. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: """Customize the request parameters for the model. This method can be overridden by subclasses to modify the request parameters before sending them to the model. In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary for vendor/model-specific reasons. """ if transformer := self.profile.json_schema_transformer: model_request_parameters = replace( model_request_parameters, function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools], output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools], ) if output_object := model_request_parameters.output_object: model_request_parameters = replace( model_request_parameters, output_object=_customize_output_object(transformer, output_object), ) return model_request_parameters ``` #### model_name ```python model_name: str ``` The model name. #### profile ```python profile: ModelProfile ``` The model profile. #### system ```python system: str ``` The system / model provider, ex: openai. Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute, so should use well-known values listed in https://opentelemetry.io/docs/specs/semconv/attributes-registry/gen-ai/#gen-ai-system when applicable. #### base_url ```python base_url: str | None ``` The base URL for the provider API, if available. ### StreamedResponse Bases: `ABC` Streamed response from an LLM when calling a tool. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python @dataclass class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) _usage: Usage = field(default_factory=Usage, init=False) def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" if self._event_iterator is None: self._event_iterator = self._get_event_iterator() return self._event_iterator @abstractmethod async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s. This method should be implemented by subclasses to translate the vendor-specific stream of events into pydantic_ai-format events. It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes. """ raise NotImplementedError() # noinspection PyUnreachableCode yield def get(self) -> ModelResponse: """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" return ModelResponse( parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp, usage=self.usage(), ) def usage(self) -> Usage: """Get the usage of the response so far. 
This will not be the final usage until the stream is exhausted.""" return self._usage @property @abstractmethod def model_name(self) -> str: """Get the model name of the response.""" raise NotImplementedError() @property @abstractmethod def timestamp(self) -> datetime: """Get the timestamp of the response.""" raise NotImplementedError() ``` #### __aiter__ ```python __aiter__() -> AsyncIterator[ModelResponseStreamEvent] ``` Stream the response as an async iterable of ModelResponseStreamEvents. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" if self._event_iterator is None: self._event_iterator = self._get_event_iterator() return self._event_iterator ``` #### get ```python get() -> ModelResponse ``` Build a ModelResponse from the data received from the stream so far. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python def get(self) -> ModelResponse: """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" return ModelResponse( parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp, usage=self.usage(), ) ``` #### usage ```python usage() -> Usage ``` Get the usage of the response so far. This will not be the final usage until the stream is exhausted. Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python def usage(self) -> Usage: """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" return self._usage ``` #### model_name ```python model_name: str ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. ### ALLOW_MODEL_REQUESTS ```python ALLOW_MODEL_REQUESTS = True ``` Whether to allow requests to models. This global setting allows you to disable requests to most models, e.g. to make sure you don't accidentally make costly requests to a model during tests. The testing models TestModel and FunctionModel are not affected by this setting. ### check_allow_model_requests ```python check_allow_model_requests() -> None ``` Check if model requests are allowed. If you're defining your own models that have costs or latency associated with their use, you should call this in Model.request and Model.request_stream. Raises: | Type | Description | | --- | --- | | `RuntimeError` | If model requests are not allowed. | Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python def check_allow_model_requests() -> None: """Check if model requests are allowed. If you're defining your own models that have costs or latency associated with their use, you should call this in [`Model.request`][pydantic_ai.models.Model.request] and [`Model.request_stream`][pydantic_ai.models.Model.request_stream]. Raises: RuntimeError: If model requests are not allowed. """ if not ALLOW_MODEL_REQUESTS: raise RuntimeError('Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False') ``` ### override_allow_model_requests ```python override_allow_model_requests( allow_model_requests: bool, ) -> Iterator[None] ``` Context manager to temporarily override ALLOW_MODEL_REQUESTS. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `allow_model_requests` | `bool` | Whether to allow model requests within the context.
| *required* | Source code in `pydantic_ai_slim/pydantic_ai/models/__init__.py` ```python @contextmanager def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: """Context manager to temporarily override [`ALLOW_MODEL_REQUESTS`][pydantic_ai.models.ALLOW_MODEL_REQUESTS]. Args: allow_model_requests: Whether to allow model requests within the context. """ global ALLOW_MODEL_REQUESTS old_value = ALLOW_MODEL_REQUESTS ALLOW_MODEL_REQUESTS = allow_model_requests # pyright: ignore[reportConstantRedefinition] try: yield finally: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] ``` # `pydantic_ai.models.bedrock` ## Setup For details on how to set up authentication with this model, see [model configuration for Bedrock](../../../models/bedrock/). ### LatestBedrockModelNames ```python LatestBedrockModelNames = Literal[ "amazon.titan-tg1-large", "amazon.titan-text-lite-v1", "amazon.titan-text-express-v1", "us.amazon.nova-pro-v1:0", "us.amazon.nova-lite-v1:0", "us.amazon.nova-micro-v1:0", "anthropic.claude-3-5-sonnet-20241022-v2:0", "us.anthropic.claude-3-5-sonnet-20241022-v2:0", "anthropic.claude-3-5-haiku-20241022-v1:0", "us.anthropic.claude-3-5-haiku-20241022-v1:0", "anthropic.claude-instant-v1", "anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-3-sonnet-20240229-v1:0", "us.anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-haiku-20240307-v1:0", "us.anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-opus-20240229-v1:0", "us.anthropic.claude-3-opus-20240229-v1:0", "anthropic.claude-3-5-sonnet-20240620-v1:0", "us.anthropic.claude-3-5-sonnet-20240620-v1:0", "anthropic.claude-3-7-sonnet-20250219-v1:0", "us.anthropic.claude-3-7-sonnet-20250219-v1:0", "anthropic.claude-opus-4-20250514-v1:0", "us.anthropic.claude-opus-4-20250514-v1:0", "anthropic.claude-sonnet-4-20250514-v1:0", "us.anthropic.claude-sonnet-4-20250514-v1:0", "cohere.command-text-v14", "cohere.command-r-v1:0", "cohere.command-r-plus-v1:0", "cohere.command-light-text-v14", "meta.llama3-8b-instruct-v1:0", "meta.llama3-70b-instruct-v1:0", "meta.llama3-1-8b-instruct-v1:0", "us.meta.llama3-1-8b-instruct-v1:0", "meta.llama3-1-70b-instruct-v1:0", "us.meta.llama3-1-70b-instruct-v1:0", "meta.llama3-1-405b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0", "us.meta.llama3-2-90b-instruct-v1:0", "us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-3-70b-instruct-v1:0", "mistral.mistral-7b-instruct-v0:2", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "mistral.mistral-large-2407-v1:0", ] ``` Latest Bedrock models. ### BedrockModelName ```python BedrockModelName = Union[str, LatestBedrockModelNames] ``` Possible Bedrock model names. Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints. See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list. ### BedrockModelSettings Bases: `ModelSettings` Settings for Bedrock models. See [the Bedrock Converse API docs](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax) for a full list. See [the boto3 implementation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html) of the Bedrock Converse API. 
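As a rough sketch of how these settings are applied, the example below passes a few Bedrock-specific options via `model_settings` when running an agent. It assumes AWS credentials are already configured for the default `'bedrock'` provider; the guardrail identifier, version and metadata values are placeholders you would replace with your own:

```python
from pydantic_ai import Agent
from pydantic_ai.models.bedrock import BedrockConverseModel, BedrockModelSettings

model = BedrockConverseModel('us.amazon.nova-micro-v1:0')
agent = Agent(model)

settings = BedrockModelSettings(
    max_tokens=1024,  # base ModelSettings field, merged with the bedrock_* fields
    bedrock_request_metadata={'project': 'docs-example'},  # placeholder metadata
    bedrock_guardrail_config={  # placeholder guardrail; must exist in your AWS account
        'guardrailIdentifier': 'my-guardrail-id',
        'guardrailVersion': '1',
    },
)

result = agent.run_sync('Summarise the Bedrock Converse API in one sentence.', model_settings=settings)
print(result.output)
```

Because all Bedrock-specific fields are `bedrock_`-prefixed, the same settings dict can be merged with settings for other providers without key collisions.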
Source code in `pydantic_ai_slim/pydantic_ai/models/bedrock.py` ```python class BedrockModelSettings(ModelSettings, total=False): """Settings for Bedrock models. See [the Bedrock Converse API docs](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax) for a full list. See [the boto3 implementation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html) of the Bedrock Converse API. """ # ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. bedrock_guardrail_config: GuardrailConfigurationTypeDef """Content moderation and safety settings for Bedrock API requests. See more about it on . """ bedrock_performance_configuration: PerformanceConfigurationTypeDef """Performance optimization settings for model inference. See more about it on . """ bedrock_request_metadata: dict[str, str] """Additional metadata to attach to Bedrock API requests. See more about it on . """ bedrock_additional_model_response_fields_paths: list[str] """JSON paths to extract additional fields from model responses. See more about it on . """ bedrock_prompt_variables: Mapping[str, PromptVariableValuesTypeDef] """Variables for substitution into prompt templates. See more about it on . """ bedrock_additional_model_requests_fields: Mapping[str, Any] """Additional model-specific parameters to include in requests. See more about it on . """ ``` #### bedrock_guardrail_config ```python bedrock_guardrail_config: GuardrailConfigurationTypeDef ``` Content moderation and safety settings for Bedrock API requests. See more about it on . #### bedrock_performance_configuration ```python bedrock_performance_configuration: ( PerformanceConfigurationTypeDef ) ``` Performance optimization settings for model inference. See more about it on . #### bedrock_request_metadata ```python bedrock_request_metadata: dict[str, str] ``` Additional metadata to attach to Bedrock API requests. See more about it on . #### bedrock_additional_model_response_fields_paths ```python bedrock_additional_model_response_fields_paths: list[str] ``` JSON paths to extract additional fields from model responses. See more about it on . #### bedrock_prompt_variables ```python bedrock_prompt_variables: Mapping[ str, PromptVariableValuesTypeDef ] ``` Variables for substitution into prompt templates. See more about it on . #### bedrock_additional_model_requests_fields ```python bedrock_additional_model_requests_fields: Mapping[str, Any] ``` Additional model-specific parameters to include in requests. See more about it on . ### BedrockConverseModel Bases: `Model` A model that uses the Bedrock Converse API. Source code in `pydantic_ai_slim/pydantic_ai/models/bedrock.py` ```python @dataclass(init=False) class BedrockConverseModel(Model): """A model that uses the Bedrock Converse API.""" client: BedrockRuntimeClient _model_name: BedrockModelName = field(repr=False) _system: str = field(default='bedrock', repr=False) @property def model_name(self) -> str: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider, ex: openai.""" return self._system def __init__( self, model_name: BedrockModelName, *, provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock', profile: ModelProfileSpec | None = None, ): """Initialize a Bedrock model. Args: model_name: The name of the model to use. model_name: The name of the Bedrock model to use. 
List of model names available [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html). provider: The provider to use for authentication and API access. Can be either the string 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = cast('BedrockRuntimeClient', provider.client) self._profile = profile or provider.model_profile def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]: tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] return tools @staticmethod def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef: return { 'toolSpec': { 'name': f.name, 'description': f.description, 'inputSchema': {'json': f.parameters_json_schema}, } } @property def base_url(self) -> str: return str(self.client.meta.endpoint_url) async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, False, settings, model_request_parameters) model_response = await self._process_response(response) model_response.usage.requests = 1 return model_response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response) async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse: items: list[ModelResponsePart] = [] if message := response['output'].get('message'): # pragma: no branch for item in message['content']: if reasoning_content := item.get('reasoningContent'): reasoning_text = reasoning_content.get('reasoningText') if reasoning_text: # pragma: no branch thinking_part = ThinkingPart(content=reasoning_text['text']) if reasoning_signature := reasoning_text.get('signature'): thinking_part.signature = reasoning_signature items.append(thinking_part) if text := item.get('text'): items.append(TextPart(content=text)) elif tool_use := item.get('toolUse'): items.append( ToolCallPart( tool_name=tool_use['name'], args=tool_use['input'], tool_call_id=tool_use['toolUseId'], ), ) u = usage.Usage( request_tokens=response['usage']['inputTokens'], response_tokens=response['usage']['outputTokens'], total_tokens=response['usage']['totalTokens'], ) vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None) return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id) @overload async def _messages_create( self, messages: list[ModelMessage], stream: Literal[True], model_settings: BedrockModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> EventStream[ConverseStreamOutputTypeDef]: pass @overload async def _messages_create( self, messages: list[ModelMessage], 
stream: Literal[False], model_settings: BedrockModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ConverseResponseTypeDef: pass async def _messages_create( self, messages: list[ModelMessage], stream: bool, model_settings: BedrockModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]: system_prompt, bedrock_messages = await self._map_messages(messages) inference_config = self._map_inference_config(model_settings) params: ConverseRequestTypeDef = { 'modelId': self.model_name, 'messages': bedrock_messages, 'system': system_prompt, 'inferenceConfig': inference_config, } tool_config = self._map_tool_config(model_request_parameters) if tool_config: params['toolConfig'] = tool_config # Bedrock supports a set of specific extra parameters if model_settings: if guardrail_config := model_settings.get('bedrock_guardrail_config', None): params['guardrailConfig'] = guardrail_config if performance_configuration := model_settings.get('bedrock_performance_configuration', None): params['performanceConfig'] = performance_configuration if request_metadata := model_settings.get('bedrock_request_metadata', None): params['requestMetadata'] = request_metadata if additional_model_response_fields_paths := model_settings.get( 'bedrock_additional_model_response_fields_paths', None ): params['additionalModelResponseFieldPaths'] = additional_model_response_fields_paths if additional_model_requests_fields := model_settings.get('bedrock_additional_model_requests_fields', None): params['additionalModelRequestFields'] = additional_model_requests_fields if prompt_variables := model_settings.get('bedrock_prompt_variables', None): params['promptVariables'] = prompt_variables if stream: model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params)) model_response = model_response['stream'] else: model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params)) return model_response @staticmethod def _map_inference_config( model_settings: ModelSettings | None, ) -> InferenceConfigurationTypeDef: model_settings = model_settings or {} inference_config: InferenceConfigurationTypeDef = {} if max_tokens := model_settings.get('max_tokens'): inference_config['maxTokens'] = max_tokens if (temperature := model_settings.get('temperature')) is not None: inference_config['temperature'] = temperature if top_p := model_settings.get('top_p'): inference_config['topP'] = top_p if stop_sequences := model_settings.get('stop_sequences'): inference_config['stopSequences'] = stop_sequences return inference_config def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None: tools = self._get_tools(model_request_parameters) if not tools: return None tool_choice: ToolChoiceTypeDef if not model_request_parameters.allow_text_output: tool_choice = {'any': {}} else: tool_choice = {'auto': {}} tool_config: ToolConfigurationTypeDef = {'tools': tools} if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice: tool_config['toolChoice'] = tool_choice return tool_config async def _map_messages( # noqa: C901 self, messages: list[ModelMessage] ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]: """Maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`. 
Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models. """ profile = BedrockModelProfile.from_profile(self.profile) system_prompt: list[SystemContentBlockTypeDef] = [] bedrock_messages: list[MessageUnionTypeDef] = [] document_count: Iterator[int] = count(1) for message in messages: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, SystemPromptPart): system_prompt.append({'text': part.content}) elif isinstance(part, UserPromptPart): bedrock_messages.extend(await self._map_user_prompt(part, document_count)) elif isinstance(part, ToolReturnPart): assert part.tool_call_id is not None bedrock_messages.append( { 'role': 'user', 'content': [ { 'toolResult': { 'toolUseId': part.tool_call_id, 'content': [ {'text': part.model_response_str()} if profile.bedrock_tool_result_format == 'text' else {'json': part.model_response_object()} ], 'status': 'success', } } ], } ) elif isinstance(part, RetryPromptPart): # TODO(Marcelo): We need to add a test here. if part.tool_name is None: # pragma: no cover bedrock_messages.append({'role': 'user', 'content': [{'text': part.model_response()}]}) else: assert part.tool_call_id is not None bedrock_messages.append( { 'role': 'user', 'content': [ { 'toolResult': { 'toolUseId': part.tool_call_id, 'content': [{'text': part.model_response()}], 'status': 'error', } } ], } ) elif isinstance(message, ModelResponse): content: list[ContentBlockOutputTypeDef] = [] for item in message.parts: if isinstance(item, TextPart): content.append({'text': item.content}) elif isinstance(item, ThinkingPart): # NOTE: We don't pass the thinking part to Bedrock since it raises an error. pass else: assert isinstance(item, ToolCallPart) content.append(self._map_tool_call(item)) bedrock_messages.append({'role': 'assistant', 'content': content}) else: assert_never(message) # Merge together sequential user messages. processed_messages: list[MessageUnionTypeDef] = [] last_message: dict[str, Any] | None = None for current_message in bedrock_messages: if ( last_message is not None and current_message['role'] == last_message['role'] and current_message['role'] == 'user' ): # Add the new user content onto the existing user message. last_content = list(last_message['content']) last_content.extend(current_message['content']) last_message['content'] = last_content continue # Add the entire message to the list of messages. 
processed_messages.append(current_message) last_message = cast(dict[str, Any], current_message) if instructions := self._get_instructions(messages): system_prompt.insert(0, {'text': instructions}) return system_prompt, processed_messages @staticmethod async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]: content: list[ContentBlockUnionTypeDef] = [] if isinstance(part.content, str): content.append({'text': part.content}) else: for item in part.content: if isinstance(item, str): content.append({'text': item}) elif isinstance(item, BinaryContent): format = item.format if item.is_document: name = f'Document {next(document_count)}' assert format in ('pdf', 'txt', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'md') content.append({'document': {'name': name, 'format': format, 'source': {'bytes': item.data}}}) elif item.is_image: assert format in ('jpeg', 'png', 'gif', 'webp') content.append({'image': {'format': format, 'source': {'bytes': item.data}}}) elif item.is_video: assert format in ('mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp') content.append({'video': {'format': format, 'source': {'bytes': item.data}}}) else: raise NotImplementedError('Binary content is not supported yet.') elif isinstance(item, (ImageUrl, DocumentUrl, VideoUrl)): downloaded_item = await download_item(item, data_format='bytes', type_format='extension') format = downloaded_item['data_type'] if item.kind == 'image-url': format = item.media_type.split('/')[1] assert format in ('jpeg', 'png', 'gif', 'webp'), f'Unsupported image format: {format}' image: ImageBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}} content.append({'image': image}) elif item.kind == 'document-url': name = f'Document {next(document_count)}' document: DocumentBlockTypeDef = { 'name': name, 'format': item.format, 'source': {'bytes': downloaded_item['data']}, } content.append({'document': document}) elif item.kind == 'video-url': # pragma: no branch format = item.media_type.split('/')[1] assert format in ( 'mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp', ), f'Unsupported video format: {format}' video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}} content.append({'video': video}) elif isinstance(item, AudioUrl): # pragma: no cover raise NotImplementedError('Audio is not supported yet.') else: assert_never(item) return [{'role': 'user', 'content': content}] @staticmethod def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef: return { 'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()} } ``` #### model_name ```python model_name: str ``` The model name. #### system ```python system: str ``` The system / model provider, ex: openai. #### __init__ ```python __init__( model_name: BedrockModelName, *, provider: ( Literal["bedrock"] | Provider[BaseClient] ) = "bedrock", profile: ModelProfileSpec | None = None ) ``` Initialize a Bedrock model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `BedrockModelName` | The name of the model to use. | *required* | | `model_name` | `BedrockModelName` | The name of the Bedrock model to use. List of model names available here. | *required* | | `provider` | `Literal['bedrock'] | Provider[BaseClient]` | The provider to use for authentication and API access. Can be either the string 'bedrock' or an instance of Provider[BaseClient]. 
If not provided, a new provider will be created using the other parameters. | `'bedrock'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/bedrock.py` ```python def __init__( self, model_name: BedrockModelName, *, provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock', profile: ModelProfileSpec | None = None, ): """Initialize a Bedrock model. Args: model_name: The name of the model to use. model_name: The name of the Bedrock model to use. List of model names available [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html). provider: The provider to use for authentication and API access. Can be either the string 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = cast('BedrockRuntimeClient', provider.client) self._profile = profile or provider.model_profile ``` ### BedrockStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for Bedrock models. Source code in `pydantic_ai_slim/pydantic_ai/models/bedrock.py` ```python @dataclass class BedrockStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Bedrock models.""" _model_name: BedrockModelName _event_stream: EventStream[ConverseStreamOutputTypeDef] _timestamp: datetime = field(default_factory=_utils.now_utc) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s. This method should be implemented by subclasses to translate the vendor-specific stream of events into pydantic_ai-format events. """ chunk: ConverseStreamOutputTypeDef tool_id: str | None = None async for chunk in _AsyncIteratorWrapper(self._event_stream): # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support. if 'messageStart' in chunk: continue if 'messageStop' in chunk: continue if 'metadata' in chunk: if 'usage' in chunk['metadata']: # pragma: no branch self._usage += self._map_usage(chunk['metadata']) continue if 'contentBlockStart' in chunk: index = chunk['contentBlockStart']['contentBlockIndex'] start = chunk['contentBlockStart']['start'] if 'toolUse' in start: # pragma: no branch tool_use_start = start['toolUse'] tool_id = tool_use_start['toolUseId'] tool_name = tool_use_start['name'] maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=index, tool_name=tool_name, args=None, tool_call_id=tool_id, ) if maybe_event: # pragma: no branch yield maybe_event if 'contentBlockDelta' in chunk: index = chunk['contentBlockDelta']['contentBlockIndex'] delta = chunk['contentBlockDelta']['delta'] if 'reasoningContent' in delta: if text := delta['reasoningContent'].get('text'): yield self._parts_manager.handle_thinking_delta(vendor_part_id=index, content=text) else: # pragma: no cover warnings.warn( f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. 
' 'Please report this to the maintainers.', UserWarning, ) if 'text' in delta: yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text']) if 'toolUse' in delta: tool_use = delta['toolUse'] maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=index, tool_name=tool_use.get('name'), args=tool_use.get('input'), tool_call_id=tool_id, ) if maybe_event: # pragma: no branch yield maybe_event @property def timestamp(self) -> datetime: return self._timestamp @property def model_name(self) -> str: """Get the model name of the response.""" return self._model_name def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage: return usage.Usage( request_tokens=metadata['usage']['inputTokens'], response_tokens=metadata['usage']['outputTokens'], total_tokens=metadata['usage']['totalTokens'], ) ``` #### model_name ```python model_name: str ``` Get the model name of the response. # `pydantic_ai.models.cohere` ## Setup For details on how to set up authentication with this model, see [model configuration for Cohere](../../../models/cohere/). ### LatestCohereModelNames ```python LatestCohereModelNames = Literal[ "c4ai-aya-expanse-32b", "c4ai-aya-expanse-8b", "command", "command-light", "command-light-nightly", "command-nightly", "command-r", "command-r-03-2024", "command-r-08-2024", "command-r-plus", "command-r-plus-04-2024", "command-r-plus-08-2024", "command-r7b-12-2024", ] ``` Latest Cohere models. ### CohereModelName ```python CohereModelName = Union[str, LatestCohereModelNames] ``` Possible Cohere model names. Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints. See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models. ### CohereModelSettings Bases: `ModelSettings` Settings used for a Cohere model request. Source code in `pydantic_ai_slim/pydantic_ai/models/cohere.py` ```python class CohereModelSettings(ModelSettings, total=False): """Settings used for a Cohere model request.""" ``` ### CohereModel Bases: `Model` A model that uses the Cohere API. Internally, this uses the [Cohere Python client](https://github.com/cohere-ai/cohere-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/cohere.py` ```python @dataclass(init=False) class CohereModel(Model): """A model that uses the Cohere API. Internally, this uses the [Cohere Python client]( https://github.com/cohere-ai/cohere-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. """ client: AsyncClientV2 = field(repr=False) _model_name: CohereModelName = field(repr=False) _system: str = field(default='cohere', repr=False) def __init__( self, model_name: CohereModelName, *, provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere', profile: ModelProfileSpec | None = None, ): """Initialize an Cohere model. Args: model_name: The name of the Cohere model to use. List of model names available [here](https://docs.cohere.com/docs/models#command). provider: The provider to use for authentication and API access. Can be either the string 'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. 
""" self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile @property def base_url(self) -> str: client_wrapper = self.client._client_wrapper # type: ignore return str(client_wrapper.get_base_url()) async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters) model_response = self._process_response(response) model_response.usage.requests = 1 return model_response @property def model_name(self) -> CohereModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system async def _chat( self, messages: list[ModelMessage], model_settings: CohereModelSettings, model_request_parameters: ModelRequestParameters, ) -> ChatResponse: tools = self._get_tools(model_request_parameters) cohere_messages = self._map_messages(messages) try: return await self.client.chat( model=self._model_name, messages=cohere_messages, tools=tools or OMIT, max_tokens=model_settings.get('max_tokens', OMIT), stop_sequences=model_settings.get('stop_sequences', OMIT), temperature=model_settings.get('temperature', OMIT), p=model_settings.get('top_p', OMIT), seed=model_settings.get('seed', OMIT), presence_penalty=model_settings.get('presence_penalty', OMIT), frequency_penalty=model_settings.get('frequency_penalty', OMIT), ) except ApiError as e: if (status_code := e.status_code) and status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover def _process_response(self, response: ChatResponse) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" parts: list[ModelResponsePart] = [] if response.message.content is not None and len(response.message.content) > 0: # While Cohere's API returns a list, it only does that for future proofing # and currently only one item is being returned. choice = response.message.content[0] parts.extend(split_content_into_text_and_thinking(choice.text)) for c in response.message.tool_calls or []: if c.function and c.function.name and c.function.arguments: # pragma: no branch parts.append( ToolCallPart( tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id or _generate_tool_call_id(), ) ) return ModelResponse(parts=parts, usage=_map_usage(response), model_name=self._model_name) def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`.""" cohere_messages: list[ChatMessageV2] = [] for message in messages: if isinstance(message, ModelRequest): cohere_messages.extend(self._map_user_message(message)) elif isinstance(message, ModelResponse): texts: list[str] = [] tool_calls: list[ToolCallV2] = [] for item in message.parts: if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ThinkingPart): # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this, # please open an issue. The below code is the code to send thinking to the provider. 
# texts.append(f'\n{item.content}\n') pass elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) else: assert_never(item) message_param = AssistantChatMessageV2(role='assistant') if texts: message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))] if tool_calls: message_param.tool_calls = tool_calls cohere_messages.append(message_param) else: assert_never(message) if instructions := self._get_instructions(messages): cohere_messages.insert(0, SystemChatMessageV2(role='system', content=instructions)) return cohere_messages def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]: tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] return tools @staticmethod def _map_tool_call(t: ToolCallPart) -> ToolCallV2: return ToolCallV2( id=_guard_tool_call_id(t=t), type='function', function=ToolCallV2Function( name=t.tool_name, arguments=t.args_as_json_str(), ), ) @staticmethod def _map_tool_definition(f: ToolDefinition) -> ToolV2: return ToolV2( type='function', function=ToolV2Function( name=f.name, description=f.description, parameters=f.parameters_json_schema, ), ) @classmethod def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]: for part in message.parts: if isinstance(part, SystemPromptPart): yield SystemChatMessageV2(role='system', content=part.content) elif isinstance(part, UserPromptPart): if isinstance(part.content, str): yield UserChatMessageV2(role='user', content=part.content) else: raise RuntimeError('Cohere does not yet support multi-modal inputs.') elif isinstance(part, ToolReturnPart): yield ToolChatMessageV2( role='tool', tool_call_id=_guard_tool_call_id(t=part), content=part.model_response_str(), ) elif isinstance(part, RetryPromptPart): if part.tool_name is None: yield UserChatMessageV2(role='user', content=part.model_response()) # pragma: no cover else: yield ToolChatMessageV2( role='tool', tool_call_id=_guard_tool_call_id(t=part), content=part.model_response(), ) else: assert_never(part) ``` #### __init__ ```python __init__( model_name: CohereModelName, *, provider: ( Literal["cohere"] | Provider[AsyncClientV2] ) = "cohere", profile: ModelProfileSpec | None = None ) ``` Initialize a Cohere model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `CohereModelName` | The name of the Cohere model to use. List of model names available here. | *required* | | `provider` | `Literal['cohere'] | Provider[AsyncClientV2]` | The provider to use for authentication and API access. Can be either the string 'cohere' or an instance of Provider[AsyncClientV2]. If not provided, a new provider will be created using the other parameters. | `'cohere'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/cohere.py` ```python def __init__( self, model_name: CohereModelName, *, provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere', profile: ModelProfileSpec | None = None, ): """Initialize a Cohere model. Args: model_name: The name of the Cohere model to use. List of model names available [here](https://docs.cohere.com/docs/models#command). provider: The provider to use for authentication and API access.
Can be either the string 'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile ``` #### model_name ```python model_name: CohereModelName ``` The model name. #### system ```python system: str ``` The system / model provider. # pydantic_ai.models.fallback ### FallbackModel Bases: `Model` A model that uses one or more fallback models upon failure. Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/fallback.py` ```python @dataclass(init=False) class FallbackModel(Model): """A model that uses one or more fallback models upon failure. Apart from `__init__`, all methods are private or match those of the base class. """ models: list[Model] _model_name: str = field(repr=False) _fallback_on: Callable[[Exception], bool] def __init__( self, default_model: Model | KnownModelName, *fallback_models: Model | KnownModelName, fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,), ): """Initialize a fallback model instance. Args: default_model: The name or instance of the default model to use. fallback_models: The names or instances of the fallback models to use upon failure. fallback_on: A callable or tuple of exceptions that should trigger a fallback. """ self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]] if isinstance(fallback_on, tuple): self._fallback_on = _default_fallback_condition_factory(fallback_on) else: self._fallback_on = fallback_on async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Try each model in sequence until one succeeds. In case of failure, raise a FallbackExceptionGroup with all exceptions. 
""" exceptions: list[Exception] = [] for model in self.models: customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) try: response = await model.request(messages, model_settings, customized_model_request_parameters) except Exception as exc: if self._fallback_on(exc): exceptions.append(exc) continue raise exc self._set_span_attributes(model) return response raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: """Try each model in sequence until one succeeds.""" exceptions: list[Exception] = [] for model in self.models: customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( model.request_stream(messages, model_settings, customized_model_request_parameters) ) except Exception as exc: if self._fallback_on(exc): exceptions.append(exc) continue raise exc # pragma: no cover self._set_span_attributes(model) yield response return raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) def _set_span_attributes(self, model: Model): with suppress(Exception): span = get_current_span() if span.is_recording(): attributes = getattr(span, 'attributes', {}) if attributes.get('gen_ai.request.model') == self.model_name: # pragma: no branch span.set_attributes(InstrumentedModel.model_attributes(model)) @property def model_name(self) -> str: """The model name.""" return f'fallback:{",".join(model.model_name for model in self.models)}' @property def system(self) -> str: return f'fallback:{",".join(model.system for model in self.models)}' @property def base_url(self) -> str | None: return self.models[0].base_url ``` #### __init__ ```python __init__( default_model: Model | KnownModelName, *fallback_models: Model | KnownModelName, fallback_on: ( Callable[[Exception], bool] | tuple[type[Exception], ...] ) = (ModelHTTPError,) ) ``` Initialize a fallback model instance. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `default_model` | `Model | KnownModelName` | The name or instance of the default model to use. | *required* | | `fallback_models` | `Model | KnownModelName` | The names or instances of the fallback models to use upon failure. | `()` | | `fallback_on` | `Callable[[Exception], bool] | tuple[type[Exception], ...]` | A callable or tuple of exceptions that should trigger a fallback. | `(ModelHTTPError,)` | Source code in `pydantic_ai_slim/pydantic_ai/models/fallback.py` ```python def __init__( self, default_model: Model | KnownModelName, *fallback_models: Model | KnownModelName, fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,), ): """Initialize a fallback model instance. Args: default_model: The name or instance of the default model to use. fallback_models: The names or instances of the fallback models to use upon failure. fallback_on: A callable or tuple of exceptions that should trigger a fallback. 
""" self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]] if isinstance(fallback_on, tuple): self._fallback_on = _default_fallback_condition_factory(fallback_on) else: self._fallback_on = fallback_on ``` #### request ```python request( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse ``` Try each model in sequence until one succeeds. In case of failure, raise a FallbackExceptionGroup with all exceptions. Source code in `pydantic_ai_slim/pydantic_ai/models/fallback.py` ```python async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Try each model in sequence until one succeeds. In case of failure, raise a FallbackExceptionGroup with all exceptions. """ exceptions: list[Exception] = [] for model in self.models: customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) try: response = await model.request(messages, model_settings, customized_model_request_parameters) except Exception as exc: if self._fallback_on(exc): exceptions.append(exc) continue raise exc self._set_span_attributes(model) return response raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) ``` #### request_stream ```python request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse] ``` Try each model in sequence until one succeeds. Source code in `pydantic_ai_slim/pydantic_ai/models/fallback.py` ```python @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: """Try each model in sequence until one succeeds.""" exceptions: list[Exception] = [] for model in self.models: customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( model.request_stream(messages, model_settings, customized_model_request_parameters) ) except Exception as exc: if self._fallback_on(exc): exceptions.append(exc) continue raise exc # pragma: no cover self._set_span_attributes(model) yield response return raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) ``` #### model_name ```python model_name: str ``` The model name. # `pydantic_ai.models.function` A model controlled by a local function. FunctionModel is similar to [`TestModel`](../test/), but allows greater control over the model's behavior. Its primary use case is for more advanced unit testing than is possible with `TestModel`. 
Here's a minimal example: function_model_usage.py ```py from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart from pydantic_ai.models.function import FunctionModel, AgentInfo my_agent = Agent('openai:gpt-4o') async def model_function( messages: list[ModelMessage], info: AgentInfo ) -> ModelResponse: print(messages) """ [ ModelRequest( parts=[ UserPromptPart( content='Testing my agent...', timestamp=datetime.datetime(...), ) ] ) ] """ print(info) """ AgentInfo( function_tools=[], allow_text_output=True, output_tools=[], model_settings=None ) """ return ModelResponse(parts=[TextPart('hello world')]) async def test_my_agent(): """Unit test for my_agent, to be run by pytest.""" with my_agent.override(model=FunctionModel(model_function)): result = await my_agent.run('Testing my agent...') assert result.output == 'hello world' ``` See [Unit testing with `FunctionModel`](../../../testing/#unit-testing-with-functionmodel) for detailed documentation. ### FunctionModel Bases: `Model` A model controlled by a local function. Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/function.py` ```python @dataclass(init=False) class FunctionModel(Model): """A model controlled by a local function. Apart from `__init__`, all methods are private or match those of the base class. """ function: FunctionDef | None = None stream_function: StreamFunctionDef | None = None _model_name: str = field(repr=False) _system: str = field(default='function', repr=False) @overload def __init__( self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None ) -> None: ... @overload def __init__( self, *, stream_function: StreamFunctionDef, model_name: str | None = None, profile: ModelProfileSpec | None = None, ) -> None: ... @overload def __init__( self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None, profile: ModelProfileSpec | None = None, ) -> None: ... def __init__( self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None, model_name: str | None = None, profile: ModelProfileSpec | None = None, ): """Initialize a `FunctionModel`. Either `function` or `stream_function` must be provided, providing both is allowed. Args: function: The function to call for non-streamed requests. stream_function: The function to call for streamed requests. model_name: The name of the model. If not provided, a name is generated from the function names. profile: The model profile to use. 
""" if function is None and stream_function is None: raise TypeError('Either `function` or `stream_function` must be provided') self.function = function self.stream_function = stream_function function_name = self.function.__name__ if self.function is not None else '' stream_function_name = self.stream_function.__name__ if self.stream_function is not None else '' self._model_name = model_name or f'function:{function_name}:{stream_function_name}' self._profile = profile async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: agent_info = AgentInfo( model_request_parameters.function_tools, model_request_parameters.allow_text_output, model_request_parameters.output_tools, model_settings, ) assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' if inspect.iscoroutinefunction(self.function): response = await self.function(messages, agent_info) else: response_ = await _utils.run_in_executor(self.function, messages, agent_info) assert isinstance(response_, ModelResponse), response_ response = response_ response.model_name = self._model_name # Add usage data if not already present if not response.usage.has_values(): # pragma: no branch response.usage = _estimate_usage(chain(messages, [response])) response.usage.requests = 1 return response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, model_request_parameters.allow_text_output, model_request_parameters.output_tools, model_settings, ) assert self.stream_function is not None, ( 'FunctionModel must receive a `stream_function` to support streamed requests' ) response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info)) first = await response_stream.peek() if isinstance(first, _utils.Unset): raise ValueError('Stream function must return at least one item') yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream) @property def model_name(self) -> str: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system ``` #### __init__ ```python __init__( function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None ) -> None ``` ```python __init__( *, stream_function: StreamFunctionDef, model_name: str | None = None, profile: ModelProfileSpec | None = None ) -> None ``` ```python __init__( function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None, profile: ModelProfileSpec | None = None ) -> None ``` ```python __init__( function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None, model_name: str | None = None, profile: ModelProfileSpec | None = None ) ``` Initialize a `FunctionModel`. Either `function` or `stream_function` must be provided, providing both is allowed. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `function` | `FunctionDef | None` | The function to call for non-streamed requests. | `None` | | `stream_function` | `StreamFunctionDef | None` | The function to call for streamed requests. | `None` | | `model_name` | `str | None` | The name of the model. If not provided, a name is generated from the function names. 
| `None` | | `profile` | `ModelProfileSpec | None` | The model profile to use. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/function.py` ```python def __init__( self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None, model_name: str | None = None, profile: ModelProfileSpec | None = None, ): """Initialize a `FunctionModel`. Either `function` or `stream_function` must be provided, providing both is allowed. Args: function: The function to call for non-streamed requests. stream_function: The function to call for streamed requests. model_name: The name of the model. If not provided, a name is generated from the function names. profile: The model profile to use. """ if function is None and stream_function is None: raise TypeError('Either `function` or `stream_function` must be provided') self.function = function self.stream_function = stream_function function_name = self.function.__name__ if self.function is not None else '' stream_function_name = self.stream_function.__name__ if self.stream_function is not None else '' self._model_name = model_name or f'function:{function_name}:{stream_function_name}' self._profile = profile ``` #### model_name ```python model_name: str ``` The model name. #### system ```python system: str ``` The system / model provider. ### AgentInfo Information about an agent. This is passed as the second argument to functions used within FunctionModel. Source code in `pydantic_ai_slim/pydantic_ai/models/function.py` ```python @dataclass(frozen=True) class AgentInfo: """Information about an agent. This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel]. """ function_tools: list[ToolDefinition] """The function tools available on this agent. These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and [`tool_plain`][pydantic_ai.Agent.tool_plain] decorators. """ allow_text_output: bool """Whether a plain text output is allowed.""" output_tools: list[ToolDefinition] """The tools that can be called to produce the final output of the run.""" model_settings: ModelSettings | None """The model settings passed to the run call.""" ``` #### function_tools ```python function_tools: list[ToolDefinition] ``` The function tools available on this agent. These are the tools registered via the tool and tool_plain decorators. #### allow_text_output ```python allow_text_output: bool ``` Whether a plain text output is allowed. #### output_tools ```python output_tools: list[ToolDefinition] ``` The tools that can be called to produce the final output of the run. #### model_settings ```python model_settings: ModelSettings | None ``` The model settings passed to the run call. ### DeltaToolCall Incremental change to a tool call. Used to describe a chunk when streaming structured responses. Source code in `pydantic_ai_slim/pydantic_ai/models/function.py` ```python @dataclass class DeltaToolCall: """Incremental change to a tool call. Used to describe a chunk when streaming structured responses. """ name: str | None = None """Incremental change to the name of the tool.""" json_args: str | None = None """Incremental change to the arguments as JSON.""" tool_call_id: str | None = None """Incremental change to the tool call ID.""" ``` #### name ```python name: str | None = None ``` Incremental change to the name of the tool.
#### json_args ```python json_args: str | None = None ``` Incremental change to the arguments as JSON. #### tool_call_id ```python tool_call_id: str | None = None ``` Incremental change to the tool call ID. ### DeltaToolCalls ```python DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall] ``` A mapping of tool call IDs to incremental changes. ### FunctionDef ```python FunctionDef: TypeAlias = Callable[ [list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]], ] ``` A function used to generate a non-streamed response. ### StreamFunctionDef ```python StreamFunctionDef: TypeAlias = Callable[ [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]], ] ``` A function used to generate a streamed response. While this is defined as having a return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]]`, i.e. you need to yield all text or all `DeltaToolCalls`, not mix them. ### FunctionStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for FunctionModel. Source code in `pydantic_ai_slim/pydantic_ai/models/function.py` ```python @dataclass class FunctionStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" _model_name: str _iter: AsyncIterator[str | DeltaToolCalls] _timestamp: datetime = field(default_factory=_utils.now_utc) def __post_init__(self): self._usage += _estimate_usage([]) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) else: delta_tool_calls = item for dtc_index, delta_tool_call in delta_tool_calls.items(): if delta_tool_call.json_args: response_tokens = _estimate_string_tokens(delta_tool_call.json_args) self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc_index, tool_name=delta_tool_call.name, args=delta_tool_call.json_args, tool_call_id=delta_tool_call.tool_call_id, ) if maybe_event is not None: yield maybe_event @property def model_name(self) -> str: """Get the model name of the response.""" return self._model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._timestamp ``` #### model_name ```python model_name: str ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. # `pydantic_ai.models.gemini` Custom interface to the `generativelanguage.googleapis.com` API using [HTTPX](https://www.python-httpx.org/) and [Pydantic](https://docs.pydantic.dev/latest/). The Google SDK for interacting with the `generativelanguage.googleapis.com` API, [`google-generativeai`](https://ai.google.dev/gemini-api/docs/quickstart?lang=python), reads like it was written by a Java developer who thought they knew everything about OOP, spent 30 minutes trying to learn Python, gave up and decided to build the library to prove how horrible Python is. It also doesn't use httpx for HTTP requests, and tries to implement tool calling itself, but doesn't use Pydantic or equivalent for validation. We therefore implement support for the API directly.
Despite these shortcomings, the Gemini model is actually quite powerful and very fast. ## Setup For details on how to set up authentication with this model, see [model configuration for Gemini](../../../models/gemini/). ### LatestGeminiModelNames ```python LatestGeminiModelNames = Literal[ "gemini-1.5-flash", "gemini-1.5-flash-8b", "gemini-1.5-pro", "gemini-1.0-pro", "gemini-2.0-flash", "gemini-2.0-flash-lite-preview-02-05", "gemini-2.0-pro-exp-02-05", "gemini-2.5-flash-preview-05-20", "gemini-2.5-flash", "gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-pro-exp-03-25", "gemini-2.5-pro-preview-05-06", "gemini-2.5-pro", ] ``` Latest Gemini models. ### GeminiModelName ```python GeminiModelName = Union[str, LatestGeminiModelNames] ``` Possible Gemini model names. Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints. See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list. ### GeminiModelSettings Bases: `ModelSettings` Settings used for a Gemini model request. Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python class GeminiModelSettings(ModelSettings, total=False): """Settings used for a Gemini model request.""" # ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. gemini_safety_settings: list[GeminiSafetySettings] """Safety settings options for Gemini model request.""" gemini_thinking_config: ThinkingConfig """Thinking is "on" by default in both the API and AI Studio. Being on by default doesn't mean the model will send back thoughts. For that, you would need to set `include_thoughts` to `True`, but since end of January 2025, `thoughts` are not returned anymore, and are only displayed in the Google AI Studio. See https://discuss.ai.google.dev/t/thoughts-are-missing-cot-not-included-anymore/63653 for more details. If you want to avoid the model spending any tokens on thinking, you can set `thinking_budget` to `0`. See more about it on . """ gemini_labels: dict[str, str] """User-defined metadata to break down billed charges. Only supported by the Vertex AI provider. See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations. """ gemini_thinking_config: ThinkingConfig """Thinking is on by default in both the API and AI Studio. Being on by default doesn't mean the model will send back thoughts. For that, you need to set `include_thoughts` to `True`. If you want to turn it off, set `thinking_budget` to `0`. See more about it on . """ ``` #### gemini_safety_settings ```python gemini_safety_settings: list[GeminiSafetySettings] ``` Safety settings options for Gemini model request. #### gemini_labels ```python gemini_labels: dict[str, str] ``` User-defined metadata to break down billed charges. Only supported by the Vertex AI provider. See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations. #### gemini_thinking_config ```python gemini_thinking_config: ThinkingConfig ``` Thinking is on by default in both the API and AI Studio. Being on by default doesn't mean the model will send back thoughts. For that, you need to set `include_thoughts` to `True`. If you want to turn it off, set `thinking_budget` to `0`. See more about it on . ### GeminiModel Bases: `Model` A model that uses Gemini via `generativelanguage.googleapis.com` API. 
This is implemented from scratch rather than using a dedicated SDK, good API documentation is available [here](https://ai.google.dev/api). Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python @dataclass(init=False) class GeminiModel(Model): """A model that uses Gemini via `generativelanguage.googleapis.com` API. This is implemented from scratch rather than using a dedicated SDK, good API documentation is available [here](https://ai.google.dev/api). Apart from `__init__`, all methods are private or match those of the base class. """ client: httpx.AsyncClient = field(repr=False) _model_name: GeminiModelName = field(repr=False) _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False) _auth: AuthProtocol | None = field(repr=False) _url: str | None = field(repr=False) _system: str = field(default='gemini', repr=False) def __init__( self, model_name: GeminiModelName, *, provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla', profile: ModelProfileSpec | None = None, ): """Initialize a Gemini model. Args: model_name: The name of the model to use. provider: The provider to use for authentication and API access. Can be either the string 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name self._provider = provider if isinstance(provider, str): provider = infer_provider(provider) self._system = provider.name self.client = provider.client self._url = str(self.client.base_url) self._profile = profile or provider.model_profile @property def base_url(self) -> str: assert self._url is not None, 'URL not initialized' # pragma: no cover return self._url # pragma: no cover async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() async with self._make_request( messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: data = await http_response.aread() response = _gemini_response_ta.validate_json(data) return self._process_response(response) @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() async with self._make_request( messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: yield await self._process_streamed_response(http_response) @property def model_name(self) -> GeminiModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None: tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools] return _GeminiTools(function_declarations=tools) if tools else None def _get_tool_config( self, model_request_parameters: 
ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: if not model_request_parameters.allow_text_output and tools: return _tool_config([t['name'] for t in tools['function_declarations']]) else: return None @asynccontextmanager async def _make_request( self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[HTTPResponse]: tools = self._get_tools(model_request_parameters) tool_config = self._get_tool_config(model_request_parameters, tools) sys_prompt_parts, contents = await self._message_to_gemini_content(messages) request_data = _GeminiRequest(contents=contents) if sys_prompt_parts: request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) if tools is not None: request_data['tools'] = tools if tool_config is not None: request_data['toolConfig'] = tool_config generation_config = _settings_to_generation_config(model_settings) if model_request_parameters.output_mode == 'native': if tools: raise UserError('Gemini does not support structured output and tools at the same time.') generation_config['response_mime_type'] = 'application/json' output_object = model_request_parameters.output_object assert output_object is not None generation_config['response_schema'] = self._map_response_schema(output_object) elif model_request_parameters.output_mode == 'prompted' and not tools: generation_config['response_mime_type'] = 'application/json' if generation_config: request_data['generationConfig'] = generation_config if gemini_safety_settings := model_settings.get('gemini_safety_settings'): request_data['safetySettings'] = gemini_safety_settings if gemini_labels := model_settings.get('gemini_labels'): if self._system == 'google-vertex': request_data['labels'] = gemini_labels # pragma: lax no cover headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}' request_json = _gemini_request_ta.dump_json(request_data, by_alias=True) async with self.client.stream( 'POST', url, content=request_json, headers=headers, timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT), ) as r: if (status_code := r.status_code) != 200: await r.aread() if status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) raise UnexpectedModelBehavior( # pragma: no cover f'Unexpected response from gemini {status_code}', r.text ) yield r def _process_response(self, response: _GeminiResponse) -> ModelResponse: vendor_details: dict[str, Any] | None = None if len(response['candidates']) != 1: raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover if 'content' not in response['candidates'][0]: if response['candidates'][0].get('finish_reason') == 'SAFETY': raise UnexpectedModelBehavior('Safety settings triggered', str(response)) else: raise UnexpectedModelBehavior( # pragma: no cover 'Content field missing from Gemini response', str(response) ) parts = response['candidates'][0]['content']['parts'] vendor_id = response.get('vendor_id', None) finish_reason = response['candidates'][0].get('finish_reason') if finish_reason: vendor_details = {'finish_reason': finish_reason} usage = _metadata_as_usage(response) usage.requests = 1 return _process_response_from_parts( parts, response.get('model_version', self._model_name), usage, vendor_id=vendor_id, vendor_details=vendor_details, ) async 
def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" aiter_bytes = http_response.aiter_bytes() start_response: _GeminiResponse | None = None content = bytearray() async for chunk in aiter_bytes: content.extend(chunk) responses = _gemini_streamed_response_ta.validate_json( _ensure_decodeable(content), experimental_allow_partial='trailing-strings', ) if responses: # pragma: no branch last = responses[-1] if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'): start_response = last break if start_response is None: raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes) async def _message_to_gemini_content( self, messages: list[ModelMessage] ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]: sys_prompt_parts: list[_GeminiTextPart] = [] contents: list[_GeminiContent] = [] for m in messages: if isinstance(m, ModelRequest): message_parts: list[_GeminiPartUnion] = [] for part in m.parts: if isinstance(part, SystemPromptPart): sys_prompt_parts.append(_GeminiTextPart(text=part.content)) elif isinstance(part, UserPromptPart): message_parts.extend(await self._map_user_prompt(part)) elif isinstance(part, ToolReturnPart): message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object())) elif isinstance(part, RetryPromptPart): if part.tool_name is None: message_parts.append(_GeminiTextPart(text=part.model_response())) # pragma: no cover else: response = {'call_error': part.model_response()} message_parts.append(_response_part_from_response(part.tool_name, response)) else: assert_never(part) if message_parts: # pragma: no branch contents.append(_GeminiContent(role='user', parts=message_parts)) elif isinstance(m, ModelResponse): contents.append(_content_model_response(m)) else: assert_never(m) if instructions := self._get_instructions(messages): sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions)) return sys_prompt_parts, contents async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]: if isinstance(part.content, str): return [{'text': part.content}] else: content: list[_GeminiPartUnion] = [] for item in part.content: if isinstance(item, str): content.append({'text': item}) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') content.append( _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type}) ) elif isinstance(item, VideoUrl) and item.is_youtube: file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type}) content.append(file_data) elif isinstance(item, FileUrl): if self.system == 'google-gla' or item.force_download: downloaded_item = await download_item(item, data_format='base64') inline_data = _GeminiInlineDataPart( inline_data={'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']} ) content.append(inline_data) else: file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type}) content.append(file_data) else: assert_never(item) return content def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: response_schema = o.json_schema.copy() if o.name: response_schema['title'] = o.name if o.description: response_schema['description'] = o.description return response_schema ``` #### 
__init__ ```python __init__( model_name: GeminiModelName, *, provider: ( Literal["google-gla", "google-vertex"] | Provider[AsyncClient] ) = "google-gla", profile: ModelProfileSpec | None = None ) ``` Initialize a Gemini model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `GeminiModelName` | The name of the model to use. | *required* | | `provider` | `Literal['google-gla', 'google-vertex'] | Provider[AsyncClient]` | The provider to use for authentication and API access. Can be either the string 'google-gla' or 'google-vertex' or an instance of Provider[httpx.AsyncClient]. If not provided, a new provider will be created using the other parameters. | `'google-gla'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python def __init__( self, model_name: GeminiModelName, *, provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla', profile: ModelProfileSpec | None = None, ): """Initialize a Gemini model. Args: model_name: The name of the model to use. provider: The provider to use for authentication and API access. Can be either the string 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name self._provider = provider if isinstance(provider, str): provider = infer_provider(provider) self._system = provider.name self.client = provider.client self._url = str(self.client.base_url) self._profile = profile or provider.model_profile ``` #### model_name ```python model_name: GeminiModelName ``` The model name. #### system ```python system: str ``` The system / model provider. ### AuthProtocol Bases: `Protocol` Abstract definition for Gemini authentication. Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python class AuthProtocol(Protocol): """Abstract definition for Gemini authentication.""" async def headers(self) -> dict[str, str]: ... ``` ### ApiKeyAuth Authentication using an API key for the `X-Goog-Api-Key` header. Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python @dataclass class ApiKeyAuth: """Authentication using an API key for the `X-Goog-Api-Key` header.""" api_key: str async def headers(self) -> dict[str, str]: # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest return {'X-Goog-Api-Key': self.api_key} # pragma: no cover ``` ### GeminiStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for the Gemini model. 
Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python @dataclass class GeminiStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for the Gemini model.""" _model_name: GeminiModelName _content: bytearray _stream: AsyncIterator[bytes] _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for gemini_response in self._get_gemini_responses(): candidate = gemini_response['candidates'][0] if 'content' not in candidate: raise UnexpectedModelBehavior('Streamed response has no content field') # pragma: no cover gemini_part: _GeminiPartUnion for gemini_part in candidate['content']['parts']: if 'text' in gemini_part: # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled # amongst the tool call deltas yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text']) elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. # We do this by assigning a unique randomly generated "vendor_part_id". # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly # it would just be a bit more complicated. And we'd need to confirm the intended semantics. maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=uuid4(), tool_name=gemini_part['function_call']['name'], args=gemini_part['function_call']['args'], tool_call_id=None, ) if maybe_event is not None: # pragma: no branch yield maybe_event else: if not any([key in gemini_part for key in ['function_response', 'thought']]): raise AssertionError(f'Unexpected part: {gemini_part}') # pragma: no cover async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]: # This method exists to ensure we only yield completed items, so we don't need to worry about # partial gemini responses, which would make everything more complicated gemini_responses: list[_GeminiResponse] = [] current_gemini_response_index = 0 # Right now, there are some circumstances where we will have information that could be yielded sooner than it is # But changing that would make things a lot more complicated. async for chunk in self._stream: self._content.extend(chunk) gemini_responses = _gemini_streamed_response_ta.validate_json( _ensure_decodeable(self._content), experimental_allow_partial='trailing-strings', ) # The idea: yield only up to the latest response, which might still be partial. # Note that if the latest response is complete, we could yield it immediately, but there's not a good # allow_partial API to determine if the last item in the list is complete. responses_to_yield = gemini_responses[:-1] for r in responses_to_yield[current_gemini_response_index:]: current_gemini_response_index += 1 yield r # Now yield the final response, which should be complete if gemini_responses: # pragma: no branch r = gemini_responses[-1] self._usage = _metadata_as_usage(r) yield r @property def model_name(self) -> GeminiModelName: """Get the model name of the response.""" return self._model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._timestamp ``` #### model_name ```python model_name: GeminiModelName ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. 
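As a usage sketch tying this module together, the example below passes the `GeminiSafetySettings` documented in the next section to a `GeminiModel` through `GeminiModelSettings`; the file name, model name, prompt, and threshold are illustrative assumptions rather than recommendations.

gemini_settings_sketch.py

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings

# Assumes credentials for the 'google-gla' provider (e.g. GEMINI_API_KEY) are configured.
model = GeminiModel('gemini-2.0-flash', provider='google-gla')
agent = Agent(model)

result = agent.run_sync(
    'Summarise how HTTP caching works in one sentence.',
    # GeminiModelSettings is a TypedDict, so the nested safety settings are plain dicts.
    model_settings=GeminiModelSettings(
        gemini_safety_settings=[
            {
                'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                'threshold': 'BLOCK_ONLY_HIGH',
            }
        ],
    ),
)
print(result.output)
```

Since `GeminiModelSettings` subclasses the `ModelSettings` `TypedDict`, a plain dictionary with the same keys works equally well.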
### GeminiSafetySettings Bases: `TypedDict` Safety settings options for Gemini model request. See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions. For an example on how to use `GeminiSafetySettings`, see [here](../../../agents/#model-specific-settings). Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python class GeminiSafetySettings(TypedDict): """Safety settings options for Gemini model request. See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions. For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings). """ category: Literal[ 'HARM_CATEGORY_UNSPECIFIED', 'HARM_CATEGORY_HARASSMENT', 'HARM_CATEGORY_HATE_SPEECH', 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'HARM_CATEGORY_DANGEROUS_CONTENT', 'HARM_CATEGORY_CIVIC_INTEGRITY', ] """ Safety settings category. """ threshold: Literal[ 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', 'BLOCK_LOW_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_ONLY_HIGH', 'BLOCK_NONE', 'OFF', ] """ Safety settings threshold. """ ``` #### category ```python category: Literal[ "HARM_CATEGORY_UNSPECIFIED", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_CIVIC_INTEGRITY", ] ``` Safety settings category. #### threshold ```python threshold: Literal[ "HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF", ] ``` Safety settings threshold. ### ThinkingConfig Bases: `TypedDict` The thinking features configuration. Source code in `pydantic_ai_slim/pydantic_ai/models/gemini.py` ```python class ThinkingConfig(TypedDict, total=False): """The thinking features configuration.""" include_thoughts: Annotated[bool, pydantic.Field(alias='includeThoughts')] """Indicates whether to include thoughts in the response. If true, thoughts are returned only if the model supports thought and thoughts are available.""" thinking_budget: Annotated[int, pydantic.Field(alias='thinkingBudget')] """Indicates the thinking budget in tokens.""" ``` #### include_thoughts ```python include_thoughts: Annotated[ bool, Field(alias=includeThoughts) ] ``` Indicates whether to include thoughts in the response. If true, thoughts are returned only if the model supports thought and thoughts are available. #### thinking_budget ```python thinking_budget: Annotated[int, Field(alias=thinkingBudget)] ``` Indicates the thinking budget in tokens. # `pydantic_ai.models.google` ## Setup For details on how to set up authentication with this model, see [model configuration for Google](../../../models/google/). ### LatestGoogleModelNames ```python LatestGoogleModelNames = Literal[ "gemini-1.5-flash", "gemini-1.5-flash-8b", "gemini-1.5-pro", "gemini-1.0-pro", "gemini-2.0-flash", "gemini-2.0-flash-lite-preview-02-05", "gemini-2.0-pro-exp-02-05", "gemini-2.5-flash-preview-05-20", "gemini-2.5-flash", "gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-pro-exp-03-25", "gemini-2.5-pro-preview-05-06", "gemini-2.5-pro", ] ``` Latest Gemini models. ### GoogleModelName ```python GoogleModelName = Union[str, LatestGoogleModelNames] ``` Possible Gemini model names. Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints. 
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list. ### GoogleModelSettings Bases: `ModelSettings` Settings used for a Gemini model request. Source code in `pydantic_ai_slim/pydantic_ai/models/google.py` ```python class GoogleModelSettings(ModelSettings, total=False): """Settings used for a Gemini model request.""" # ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. google_safety_settings: list[SafetySettingDict] """The safety settings to use for the model. See for more information. """ google_thinking_config: ThinkingConfigDict """The thinking configuration to use for the model. See for more information. """ google_labels: dict[str, str] """User-defined metadata to break down billed charges. Only supported by the Vertex AI API. See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations. """ google_video_resolution: MediaResolution """The video resolution to use for the model. See for more information. """ ``` #### google_safety_settings ```python google_safety_settings: list[SafetySettingDict] ``` The safety settings to use for the model. See for more information. #### google_thinking_config ```python google_thinking_config: ThinkingConfigDict ``` The thinking configuration to use for the model. See for more information. #### google_labels ```python google_labels: dict[str, str] ``` User-defined metadata to break down billed charges. Only supported by the Vertex AI API. See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations. #### google_video_resolution ```python google_video_resolution: MediaResolution ``` The video resolution to use for the model. See for more information. ### GoogleModel Bases: `Model` A model that uses Gemini via `generativelanguage.googleapis.com` API. This is implemented from scratch rather than using a dedicated SDK, good API documentation is available [here](https://ai.google.dev/api). Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/google.py` ```python @dataclass(init=False) class GoogleModel(Model): """A model that uses Gemini via `generativelanguage.googleapis.com` API. This is implemented from scratch rather than using a dedicated SDK, good API documentation is available [here](https://ai.google.dev/api). Apart from `__init__`, all methods are private or match those of the base class. """ client: genai.Client = field(repr=False) _model_name: GoogleModelName = field(repr=False) _provider: Provider[genai.Client] = field(repr=False) _url: str | None = field(repr=False) _system: str = field(default='google', repr=False) def __init__( self, model_name: GoogleModelName, *, provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla', profile: ModelProfileSpec | None = None, ): """Initialize a Gemini model. Args: model_name: The name of the model to use. provider: The provider to use for authentication and API access. Can be either the string 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. 
""" self._model_name = model_name if isinstance(provider, str): provider = GoogleProvider(vertexai=provider == 'google-vertex') # pragma: lax no cover self._provider = provider self._system = provider.name self.client = provider.client self._profile = profile or provider.model_profile @property def base_url(self) -> str: return self._provider.base_url async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, False, model_settings, model_request_parameters) return self._process_response(response) @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, True, model_settings, model_request_parameters) yield await self._process_streamed_response(response) # type: ignore @property def model_name(self) -> GoogleModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None: tools: list[ToolDict] = [ ToolDict(function_declarations=[_function_declaration_from_tool(t)]) for t in model_request_parameters.function_tools ] if model_request_parameters.output_tools: tools += [ ToolDict(function_declarations=[_function_declaration_from_tool(t)]) for t in model_request_parameters.output_tools ] return tools or None def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: if not model_request_parameters.allow_text_output and tools: names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: if name := function_declaration.get('name'): # pragma: no branch names.append(name) return _tool_config(names) else: return None @overload async def _generate_content( self, messages: list[ModelMessage], stream: Literal[False], model_settings: GoogleModelSettings, model_request_parameters: ModelRequestParameters, ) -> GenerateContentResponse: ... @overload async def _generate_content( self, messages: list[ModelMessage], stream: Literal[True], model_settings: GoogleModelSettings, model_request_parameters: ModelRequestParameters, ) -> Awaitable[AsyncIterator[GenerateContentResponse]]: ... 
async def _generate_content( self, messages: list[ModelMessage], stream: bool, model_settings: GoogleModelSettings, model_request_parameters: ModelRequestParameters, ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]: tools = self._get_tools(model_request_parameters) response_mime_type = None response_schema = None if model_request_parameters.output_mode == 'native': if tools: raise UserError('Gemini does not support structured output and tools at the same time.') response_mime_type = 'application/json' output_object = model_request_parameters.output_object assert output_object is not None response_schema = self._map_response_schema(output_object) elif model_request_parameters.output_mode == 'prompted' and not tools: response_mime_type = 'application/json' tool_config = self._get_tool_config(model_request_parameters, tools) system_instruction, contents = await self._map_messages(messages) http_options: HttpOptionsDict = { 'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} } if timeout := model_settings.get('timeout'): if isinstance(timeout, (int, float)): http_options['timeout'] = int(1000 * timeout) else: raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout') config = GenerateContentConfigDict( http_options=http_options, system_instruction=system_instruction, temperature=model_settings.get('temperature'), top_p=model_settings.get('top_p'), max_output_tokens=model_settings.get('max_tokens'), stop_sequences=model_settings.get('stop_sequences'), presence_penalty=model_settings.get('presence_penalty'), frequency_penalty=model_settings.get('frequency_penalty'), safety_settings=model_settings.get('google_safety_settings'), thinking_config=model_settings.get('google_thinking_config'), labels=model_settings.get('google_labels'), media_resolution=model_settings.get('google_video_resolution'), tools=cast(ToolListUnionDict, tools), tool_config=tool_config, response_mime_type=response_mime_type, response_schema=response_schema, ) func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content return await func(model=self._model_name, contents=contents, config=config) # type: ignore def _process_response(self, response: GenerateContentResponse) -> ModelResponse: if not response.candidates or len(response.candidates) != 1: raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover if response.candidates[0].content is None or response.candidates[0].content.parts is None: if response.candidates[0].finish_reason == 'SAFETY': raise UnexpectedModelBehavior('Safety settings triggered', str(response)) else: raise UnexpectedModelBehavior( 'Content field missing from Gemini response', str(response) ) # pragma: no cover parts = response.candidates[0].content.parts or [] vendor_id = response.response_id or None vendor_details: dict[str, Any] | None = None finish_reason = response.candidates[0].finish_reason if finish_reason: # pragma: no branch vendor_details = {'finish_reason': finish_reason.value} usage = _metadata_as_usage(response) usage.requests = 1 return _process_response_from_parts( parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details ) async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = 
_utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') # pragma: no cover return GeminiStreamedResponse( _model_name=self._model_name, _response=peekable_response, _timestamp=first_chunk.create_time or _utils.now_utc(), ) async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: contents: list[ContentUnionDict] = [] system_parts: list[PartDict] = [] for m in messages: if isinstance(m, ModelRequest): message_parts: list[PartDict] = [] for part in m.parts: if isinstance(part, SystemPromptPart): system_parts.append({'text': part.content}) elif isinstance(part, UserPromptPart): message_parts.extend(await self._map_user_prompt(part)) elif isinstance(part, ToolReturnPart): message_parts.append( { 'function_response': { 'name': part.tool_name, 'response': part.model_response_object(), 'id': part.tool_call_id, } } ) elif isinstance(part, RetryPromptPart): if part.tool_name is None: message_parts.append({'text': part.model_response()}) # pragma: no cover else: message_parts.append( { 'function_response': { 'name': part.tool_name, 'response': {'call_error': part.model_response()}, 'id': part.tool_call_id, } } ) else: assert_never(part) # Google GenAI requires at least one part in the message. if not message_parts: message_parts = [{'text': ''}] contents.append({'role': 'user', 'parts': message_parts}) elif isinstance(m, ModelResponse): contents.append(_content_model_response(m)) else: assert_never(m) if instructions := self._get_instructions(messages): system_parts.insert(0, {'text': instructions}) system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None return system_instruction, contents async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: if isinstance(part.content, str): return [{'text': part.content}] else: content: list[PartDict] = [] for item in part.content: if isinstance(item, str): content.append({'text': item}) elif isinstance(item, BinaryContent): # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`. 
base64_encoded = base64.b64encode(item.data).decode('utf-8') inline_data_dict = {'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}} if item.vendor_metadata: inline_data_dict['video_metadata'] = item.vendor_metadata content.append(inline_data_dict) # type: ignore elif isinstance(item, VideoUrl) and item.is_youtube: file_data_dict = {'file_data': {'file_uri': item.url, 'mime_type': item.media_type}} if item.vendor_metadata: file_data_dict['video_metadata'] = item.vendor_metadata content.append(file_data_dict) # type: ignore elif isinstance(item, FileUrl): if self.system == 'google-gla' or item.force_download: downloaded_item = await download_item(item, data_format='base64') inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']} content.append({'inline_data': inline_data}) # type: ignore else: content.append({'file_data': {'file_uri': item.url, 'mime_type': item.media_type}}) else: assert_never(item) return content def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: response_schema = o.json_schema.copy() if o.name: response_schema['title'] = o.name if o.description: response_schema['description'] = o.description return response_schema ``` #### __init__ ```python __init__( model_name: GoogleModelName, *, provider: ( Literal["google-gla", "google-vertex"] | Provider[Client] ) = "google-gla", profile: ModelProfileSpec | None = None ) ``` Initialize a Gemini model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `GoogleModelName` | The name of the model to use. | *required* | | `provider` | `Literal['google-gla', 'google-vertex'] | Provider[Client]` | The provider to use for authentication and API access. Can be either the string 'google-gla' or 'google-vertex' or an instance of Provider[httpx.AsyncClient]. If not provided, a new provider will be created using the other parameters. | `'google-gla'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/google.py` ```python def __init__( self, model_name: GoogleModelName, *, provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla', profile: ModelProfileSpec | None = None, ): """Initialize a Gemini model. Args: model_name: The name of the model to use. provider: The provider to use for authentication and API access. Can be either the string 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = GoogleProvider(vertexai=provider == 'google-vertex') # pragma: lax no cover self._provider = provider self._system = provider.name self.client = provider.client self._profile = profile or provider.model_profile ``` #### model_name ```python model_name: GoogleModelName ``` The model name. #### system ```python system: str ``` The system / model provider. ### GeminiStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for the Gemini model. 
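You typically don't instantiate `GeminiStreamedResponse` yourself; it's created for you when an agent run is streamed against a `GoogleModel`. As a minimal, hedged sketch of that path (not taken from the official examples: the model name and the thinking config are illustrative assumptions, and it requires the `google` optional dependency plus Gemini credentials):

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

model = GoogleModel('gemini-1.5-flash', provider='google-gla')
settings = GoogleModelSettings(
    temperature=0.0,
    google_thinking_config={'include_thoughts': True},  # assumed ThinkingConfigDict field
)
agent = Agent(model, model_settings=settings)


async def main():
    # Streaming the run is what drives GeminiStreamedResponse under the hood.
    async with agent.run_stream('Summarise the rules of roulette in one sentence.') as response:
        print(await response.get_output())


asyncio.run(main())
```

Each streamed chunk is turned into text, thinking, or tool-call delta events by `_get_event_iterator` in the source below.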
Source code in `pydantic_ai_slim/pydantic_ai/models/google.py`

```python
@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    _model_name: GoogleModelName
    _response: AsyncIterator[GenerateContentResponse]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage = _metadata_as_usage(chunk)

            assert chunk.candidates is not None
            candidate = chunk.candidates[0]
            if candidate.content is None:
                raise UnexpectedModelBehavior('Streamed response has no content field')  # pragma: no cover
            assert candidate.content.parts is not None
            for part in candidate.content.parts:
                if part.text is not None:
                    if part.thought:
                        yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
                    else:
                        yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
                elif part.function_call:
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=part.function_call.name,
                        args=part.function_call.args,
                        tool_call_id=part.function_call.id,
                    )
                    if maybe_event is not None:  # pragma: no branch
                        yield maybe_event
                else:
                    assert part.function_response is not None, f'Unexpected part: {part}'  # pragma: no cover

    @property
    def model_name(self) -> GoogleModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
```

#### model_name

```python
model_name: GoogleModelName
```

Get the model name of the response.

#### timestamp

```python
timestamp: datetime
```

Get the timestamp of the response.

# `pydantic_ai.models.groq`

## Setup

For details on how to set up authentication with this model, see [model configuration for Groq](../../../models/groq/).

### ProductionGroqModelNames

```python
ProductionGroqModelNames = Literal[
    "distil-whisper-large-v3-en",
    "gemma2-9b-it",
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "llama-guard-3-8b",
    "llama3-70b-8192",
    "llama3-8b-8192",
    "whisper-large-v3",
    "whisper-large-v3-turbo",
]
```

Production Groq models from [the Groq docs](https://console.groq.com/docs/models).

### PreviewGroqModelNames

```python
PreviewGroqModelNames = Literal[
    "playai-tts",
    "playai-tts-arabic",
    "qwen-qwq-32b",
    "mistral-saba-24b",
    "qwen-2.5-coder-32b",
    "qwen-2.5-32b",
    "deepseek-r1-distill-qwen-32b",
    "deepseek-r1-distill-llama-70b",
    "llama-3.3-70b-specdec",
    "llama-3.2-1b-preview",
    "llama-3.2-3b-preview",
    "llama-3.2-11b-vision-preview",
    "llama-3.2-90b-vision-preview",
]
```

Preview Groq models from [the Groq docs](https://console.groq.com/docs/models).

### GroqModelName

```python
GroqModelName = Union[
    str, ProductionGroqModelNames, PreviewGroqModelNames
]
```

Possible Groq model names.

Since Groq supports a variety of models and the list changes frequently, we explicitly list the named models as of 2025-03-31 but allow any name in the type hints. See [the Groq docs](https://console.groq.com/docs/models) for an up-to-date list of models and more details.

### GroqModelSettings

Bases: `ModelSettings`

Settings used for a Groq model request.

Source code in `pydantic_ai_slim/pydantic_ai/models/groq.py`

```python
class GroqModelSettings(ModelSettings, total=False):
    """Settings used for a Groq model request."""

    # ALL FIELDS MUST BE `groq_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

    groq_reasoning_format: Literal['hidden', 'raw', 'parsed']
```

### GroqModel

Bases: `Model`

A model that uses the Groq API.

Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.
Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/groq.py` ```python @dataclass(init=False) class GroqModel(Model): """A model that uses the Groq API. Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. """ client: AsyncGroq = field(repr=False) _model_name: GroqModelName = field(repr=False) _system: str = field(default='groq', repr=False) def __init__( self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq', profile: ModelProfileSpec | None = None, ): """Initialize a Groq model. Args: model_name: The name of the Groq model to use. List of model names available [here](https://console.groq.com/docs/models). provider: The provider to use for authentication and API access. Can be either the string 'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile @property def base_url(self) -> str: return str(self.client.base_url) async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() response = await self._completions_create( messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) model_response.usage.requests = 1 return model_response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) async with response: yield await self._process_streamed_response(response) @property def model_name(self) -> GroqModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system @overload async def _completions_create( self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings, model_request_parameters: ModelRequestParameters, ) -> AsyncStream[chat.ChatCompletionChunk]: pass @overload async def _completions_create( self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion: pass async def _completions_create( self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]: tools = self._get_tools(model_request_parameters) # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None elif not model_request_parameters.allow_text_output: tool_choice = 'required' else: tool_choice = 'auto' groq_messages = self._map_messages(messages) try: extra_headers = 
model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) return await self.client.chat.completions.create( model=str(self._model_name), messages=groq_messages, n=1, parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stop=model_settings.get('stop_sequences', NOT_GIVEN), stream=stream, max_tokens=model_settings.get('max_tokens', NOT_GIVEN), temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), seed=model_settings.get('seed', NOT_GIVEN), presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), reasoning_format=model_settings.get('groq_reasoning_format', NOT_GIVEN), frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), logit_bias=model_settings.get('logit_bias', NOT_GIVEN), extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = number_to_datetime(response.created) choice = response.choices[0] items: list[ModelResponsePart] = [] # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`. if choice.message.reasoning is not None: items.append(ThinkingPart(content=choice.message.reasoning)) if choice.message.content is not None: # NOTE: The `` tag is only present if `groq_reasoning_format` is set to `raw`. items.extend(split_content_into_text_and_thinking(choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)) return ModelResponse( items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id ) async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior( # pragma: no cover 'Streamed response ended without content or tool calls' ) return GroqStreamedResponse( _response=peekable_response, _model_name=self._model_name, _timestamp=number_to_datetime(first_chunk.created), ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] return tools def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`.""" groq_messages: list[chat.ChatCompletionMessageParam] = [] for message in messages: if isinstance(message, ModelRequest): groq_messages.extend(self._map_user_message(message)) elif isinstance(message, ModelResponse): texts: list[str] = [] tool_calls: 
list[chat.ChatCompletionMessageToolCallParam] = [] for item in message.parts: if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) elif isinstance(item, ThinkingPart): # Skip thinking parts when mapping to Groq messages continue else: assert_never(item) message_param = chat.ChatCompletionAssistantMessageParam(role='assistant') if texts: # Note: model responses from this model should only have one text item, so the following # shouldn't merge multiple texts into one unless you switch models between runs: message_param['content'] = '\n\n'.join(texts) if tool_calls: message_param['tool_calls'] = tool_calls groq_messages.append(message_param) else: assert_never(message) if instructions := self._get_instructions(messages): groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions)) return groq_messages @staticmethod def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: return chat.ChatCompletionMessageToolCallParam( id=_guard_tool_call_id(t=t), type='function', function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, ) @staticmethod def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: return { 'type': 'function', 'function': { 'name': f.name, 'description': f.description, 'parameters': f.parameters_json_schema, }, } @classmethod def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]: for part in message.parts: if isinstance(part, SystemPromptPart): yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content) elif isinstance(part, UserPromptPart): yield cls._map_user_prompt(part) elif isinstance(part, ToolReturnPart): yield chat.ChatCompletionToolMessageParam( role='tool', tool_call_id=_guard_tool_call_id(t=part), content=part.model_response_str(), ) elif isinstance(part, RetryPromptPart): # pragma: no branch if part.tool_name is None: yield chat.ChatCompletionUserMessageParam( # pragma: no cover role='user', content=part.model_response() ) else: yield chat.ChatCompletionToolMessageParam( role='tool', tool_call_id=_guard_tool_call_id(t=part), content=part.model_response(), ) @staticmethod def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam: content: str | list[chat.ChatCompletionContentPartParam] if isinstance(part.content, str): content = part.content else: content = [] for item in part.content: if isinstance(item, str): content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text')) elif isinstance(item, ImageUrl): image_url = ImageURL(url=item.url) content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') if item.is_image: image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}') content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) else: raise RuntimeError('Only images are supported for binary content in Groq.') elif isinstance(item, DocumentUrl): # pragma: no cover raise RuntimeError('DocumentUrl is not supported in Groq.') else: # pragma: no cover raise RuntimeError(f'Unsupported content type: {type(item)}') return chat.ChatCompletionUserMessageParam(role='user', content=content) ``` #### __init__ ```python __init__( model_name: GroqModelName, *, provider: ( Literal["groq"] | Provider[AsyncGroq] ) = "groq", 
profile: ModelProfileSpec | None = None ) ``` Initialize a Groq model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `GroqModelName` | The name of the Groq model to use. List of model names available here. | *required* | | `provider` | `Literal['groq'] | Provider[AsyncGroq]` | The provider to use for authentication and API access. Can be either the string 'groq' or an instance of Provider[AsyncGroq]. If not provided, a new provider will be created using the other parameters. | `'groq'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/groq.py` ```python def __init__( self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq', profile: ModelProfileSpec | None = None, ): """Initialize a Groq model. Args: model_name: The name of the Groq model to use. List of model names available [here](https://console.groq.com/docs/models). provider: The provider to use for authentication and API access. Can be either the string 'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile ``` #### model_name ```python model_name: GroqModelName ``` The model name. #### system ```python system: str ``` The system / model provider. ### GroqStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for Groq models. Source code in `pydantic_ai_slim/pydantic_ai/models/groq.py` ```python @dataclass class GroqStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Groq models.""" _model_name: GroqModelName _response: AsyncIterable[chat.ChatCompletionChunk] _timestamp: datetime async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: self._usage += _map_usage(chunk) try: choice = chunk.choices[0] except IndexError: continue # Handle the text part of the response content = choice.delta.content if content is not None: yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) # Handle the tool calls for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc.index, tool_name=dtc.function and dtc.function.name, args=dtc.function and dtc.function.arguments, tool_call_id=dtc.id, ) if maybe_event is not None: yield maybe_event @property def model_name(self) -> GroqModelName: """Get the model name of the response.""" return self._model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._timestamp ``` #### model_name ```python model_name: GroqModelName ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. # pydantic_ai.models.instrumented ### instrument_model ```python instrument_model( model: Model, instrument: InstrumentationSettings | bool ) -> Model ``` Instrument a model with OpenTelemetry/logfire. 
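As a quick, hedged sketch of how `instrument_model` is typically used (with `TestModel` standing in as the wrapped model so nothing hits the network):

```python
from pydantic_ai.models.instrumented import InstrumentationSettings, instrument_model
from pydantic_ai.models.test import TestModel

# Emit events as OpenTelemetry logs and keep binary payloads out of telemetry.
settings = InstrumentationSettings(event_mode='logs', include_binary_content=False)

model = instrument_model(TestModel(), settings)
print(type(model).__name__)
#> InstrumentedModel

# A falsy `instrument` (or an already-instrumented model) is returned unchanged.
assert instrument_model(model, False) is model
```

Passing `True` instead of an `InstrumentationSettings` instance falls back to default settings, as the source below shows.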
Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python def instrument_model(model: Model, instrument: InstrumentationSettings | bool) -> Model: """Instrument a model with OpenTelemetry/logfire.""" if instrument and not isinstance(model, InstrumentedModel): if instrument is True: instrument = InstrumentationSettings() model = InstrumentedModel(model, instrument) return model ``` ### InstrumentationSettings Options for instrumenting models and agents with OpenTelemetry. Used in: - `Agent(instrument=...)` - Agent.instrument_all() - InstrumentedModel See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python @dataclass(init=False) class InstrumentationSettings: """Options for instrumenting models and agents with OpenTelemetry. Used in: - `Agent(instrument=...)` - [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all] - [`InstrumentedModel`][pydantic_ai.models.instrumented.InstrumentedModel] See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. """ tracer: Tracer = field(repr=False) event_logger: EventLogger = field(repr=False) event_mode: Literal['attributes', 'logs'] = 'attributes' include_binary_content: bool = True def __init__( self, *, event_mode: Literal['attributes', 'logs'] = 'attributes', tracer_provider: TracerProvider | None = None, meter_provider: MeterProvider | None = None, event_logger_provider: EventLoggerProvider | None = None, include_binary_content: bool = True, include_content: bool = True, ): """Create instrumentation options. Args: event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes. If `'logs'`, events are emitted as OpenTelemetry log-based events. tracer_provider: The OpenTelemetry tracer provider to use. If not provided, the global tracer provider is used. Calling `logfire.configure()` sets the global tracer provider, so most users don't need this. meter_provider: The OpenTelemetry meter provider to use. If not provided, the global meter provider is used. Calling `logfire.configure()` sets the global meter provider, so most users don't need this. event_logger_provider: The OpenTelemetry event logger provider to use. If not provided, the global event logger provider is used. Calling `logfire.configure()` sets the global event logger provider, so most users don't need this. This is only used if `event_mode='logs'`. include_binary_content: Whether to include binary content in the instrumentation events. include_content: Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. 
""" from pydantic_ai import __version__ tracer_provider = tracer_provider or get_tracer_provider() meter_provider = meter_provider or get_meter_provider() event_logger_provider = event_logger_provider or get_event_logger_provider() scope_name = 'pydantic-ai' self.tracer = tracer_provider.get_tracer(scope_name, __version__) self.meter = meter_provider.get_meter(scope_name, __version__) self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__) self.event_mode = event_mode self.include_binary_content = include_binary_content self.include_content = include_content # As specified in the OpenTelemetry GenAI metrics spec: # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage tokens_histogram_kwargs = dict( name='gen_ai.client.token.usage', unit='{token}', description='Measures number of input and output tokens used', ) try: self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES, ) except TypeError: # pragma: lax no cover # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, # pyright: ignore ) def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: """Convert a list of model messages to OpenTelemetry events. Args: messages: The messages to convert. Returns: A list of OpenTelemetry events. """ events: list[Event] = [] instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage] if instructions is not None: events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'})) for message_index, message in enumerate(messages): message_events: list[Event] = [] if isinstance(message, ModelRequest): for part in message.parts: if hasattr(part, 'otel_event'): message_events.append(part.otel_event(self)) elif isinstance(message, ModelResponse): # pragma: no branch message_events = message.otel_events(self) for event in message_events: event.attributes = { 'gen_ai.message.index': message_index, **(event.attributes or {}), } events.extend(message_events) for event in events: event.body = InstrumentedModel.serialize_any(event.body) return events ``` #### __init__ ```python __init__( *, event_mode: Literal[ "attributes", "logs" ] = "attributes", tracer_provider: TracerProvider | None = None, meter_provider: MeterProvider | None = None, event_logger_provider: ( EventLoggerProvider | None ) = None, include_binary_content: bool = True, include_content: bool = True ) ``` Create instrumentation options. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `event_mode` | `Literal['attributes', 'logs']` | The mode for emitting events. If 'attributes', events are attached to the span as attributes. If 'logs', events are emitted as OpenTelemetry log-based events. | `'attributes'` | | `tracer_provider` | `TracerProvider | None` | The OpenTelemetry tracer provider to use. If not provided, the global tracer provider is used. Calling logfire.configure() sets the global tracer provider, so most users don't need this. | `None` | | `meter_provider` | `MeterProvider | None` | The OpenTelemetry meter provider to use. If not provided, the global meter provider is used. Calling logfire.configure() sets the global meter provider, so most users don't need this. 
| `None` | | `event_logger_provider` | `EventLoggerProvider | None` | The OpenTelemetry event logger provider to use. If not provided, the global event logger provider is used. Calling logfire.configure() sets the global event logger provider, so most users don't need this. This is only used if event_mode='logs'. | `None` | | `include_binary_content` | `bool` | Whether to include binary content in the instrumentation events. | `True` | | `include_content` | `bool` | Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. | `True` | Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python def __init__( self, *, event_mode: Literal['attributes', 'logs'] = 'attributes', tracer_provider: TracerProvider | None = None, meter_provider: MeterProvider | None = None, event_logger_provider: EventLoggerProvider | None = None, include_binary_content: bool = True, include_content: bool = True, ): """Create instrumentation options. Args: event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes. If `'logs'`, events are emitted as OpenTelemetry log-based events. tracer_provider: The OpenTelemetry tracer provider to use. If not provided, the global tracer provider is used. Calling `logfire.configure()` sets the global tracer provider, so most users don't need this. meter_provider: The OpenTelemetry meter provider to use. If not provided, the global meter provider is used. Calling `logfire.configure()` sets the global meter provider, so most users don't need this. event_logger_provider: The OpenTelemetry event logger provider to use. If not provided, the global event logger provider is used. Calling `logfire.configure()` sets the global event logger provider, so most users don't need this. This is only used if `event_mode='logs'`. include_binary_content: Whether to include binary content in the instrumentation events. include_content: Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. """ from pydantic_ai import __version__ tracer_provider = tracer_provider or get_tracer_provider() meter_provider = meter_provider or get_meter_provider() event_logger_provider = event_logger_provider or get_event_logger_provider() scope_name = 'pydantic-ai' self.tracer = tracer_provider.get_tracer(scope_name, __version__) self.meter = meter_provider.get_meter(scope_name, __version__) self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__) self.event_mode = event_mode self.include_binary_content = include_binary_content self.include_content = include_content # As specified in the OpenTelemetry GenAI metrics spec: # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage tokens_histogram_kwargs = dict( name='gen_ai.client.token.usage', unit='{token}', description='Measures number of input and output tokens used', ) try: self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES, ) except TypeError: # pragma: lax no cover # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, # pyright: ignore ) ``` #### messages_to_otel_events ```python messages_to_otel_events( messages: list[ModelMessage], ) -> list[Event] ``` Convert a list of model messages to OpenTelemetry events. 
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `messages` | `list[ModelMessage]` | The messages to convert. | *required* | Returns: | Type | Description | | --- | --- | | `list[Event]` | A list of OpenTelemetry events. | Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: """Convert a list of model messages to OpenTelemetry events. Args: messages: The messages to convert. Returns: A list of OpenTelemetry events. """ events: list[Event] = [] instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage] if instructions is not None: events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'})) for message_index, message in enumerate(messages): message_events: list[Event] = [] if isinstance(message, ModelRequest): for part in message.parts: if hasattr(part, 'otel_event'): message_events.append(part.otel_event(self)) elif isinstance(message, ModelResponse): # pragma: no branch message_events = message.otel_events(self) for event in message_events: event.attributes = { 'gen_ai.message.index': message_index, **(event.attributes or {}), } events.extend(message_events) for event in events: event.body = InstrumentedModel.serialize_any(event.body) return events ``` ### InstrumentedModel Bases: `WrapperModel` Model which wraps another model so that requests are instrumented with OpenTelemetry. See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. Source code in `pydantic_ai_slim/pydantic_ai/models/instrumented.py` ```python @dataclass class InstrumentedModel(WrapperModel): """Model which wraps another model so that requests are instrumented with OpenTelemetry. See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. 
""" settings: InstrumentationSettings """Configuration for instrumenting requests.""" def __init__( self, wrapped: Model | KnownModelName, options: InstrumentationSettings | None = None, ) -> None: super().__init__(wrapped) self.settings = options or InstrumentationSettings() async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: with self._instrument(messages, model_settings, model_request_parameters) as finish: response = await super().request(messages, model_settings, model_request_parameters) finish(response) return response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: with self._instrument(messages, model_settings, model_request_parameters) as finish: response_stream: StreamedResponse | None = None try: async with super().request_stream( messages, model_settings, model_request_parameters ) as response_stream: yield response_stream finally: if response_stream: # pragma: no branch finish(response_stream.get()) @contextmanager def _instrument( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> Iterator[Callable[[ModelResponse], None]]: operation = 'chat' span_name = f'{operation} {self.model_name}' # TODO Missing attributes: # - error.type: unclear if we should do something here or just always rely on span exceptions # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these attributes: dict[str, AttributeValue] = { 'gen_ai.operation.name': operation, **self.model_attributes(self.wrapped), 'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters)), 'logfire.json_schema': json.dumps( { 'type': 'object', 'properties': {'model_request_parameters': {'type': 'object'}}, } ), } if model_settings: for key in MODEL_SETTING_ATTRIBUTES: if isinstance(value := model_settings.get(key), (float, int)): attributes[f'gen_ai.request.{key}'] = value record_metrics: Callable[[], None] | None = None try: with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span: def finish(response: ModelResponse): # FallbackModel updates these span attributes. 
attributes.update(getattr(span, 'attributes', {})) request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE] system = attributes[GEN_AI_SYSTEM_ATTRIBUTE] response_model = response.model_name or request_model def _record_metrics(): metric_attributes = { GEN_AI_SYSTEM_ATTRIBUTE: system, 'gen_ai.operation.name': operation, 'gen_ai.request.model': request_model, 'gen_ai.response.model': response_model, } if response.usage.request_tokens: # pragma: no branch self.settings.tokens_histogram.record( response.usage.request_tokens, {**metric_attributes, 'gen_ai.token.type': 'input'}, ) if response.usage.response_tokens: # pragma: no branch self.settings.tokens_histogram.record( response.usage.response_tokens, {**metric_attributes, 'gen_ai.token.type': 'output'}, ) nonlocal record_metrics record_metrics = _record_metrics if not span.is_recording(): return events = self.settings.messages_to_otel_events(messages) for event in self.settings.messages_to_otel_events([response]): events.append( Event( 'gen_ai.choice', body={ # TODO finish_reason 'index': 0, 'message': event.body, }, ) ) span.set_attributes( { **response.usage.opentelemetry_attributes(), 'gen_ai.response.model': response_model, } ) span.update_name(f'{operation} {request_model}') for event in events: event.attributes = { GEN_AI_SYSTEM_ATTRIBUTE: system, **(event.attributes or {}), } self._emit_events(span, events) yield finish finally: if record_metrics: # We only want to record metrics after the span is finished, # to prevent them from being redundantly recorded in the span itself by logfire. record_metrics() def _emit_events(self, span: Span, events: list[Event]) -> None: if self.settings.event_mode == 'logs': for event in events: self.settings.event_logger.emit(event) else: attr_name = 'events' span.set_attributes( { attr_name: json.dumps([self.event_to_dict(event) for event in events]), 'logfire.json_schema': json.dumps( { 'type': 'object', 'properties': { attr_name: {'type': 'array'}, 'model_request_parameters': {'type': 'object'}, }, } ), } ) @staticmethod def model_attributes(model: Model): attributes: dict[str, AttributeValue] = { GEN_AI_SYSTEM_ATTRIBUTE: model.system, GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name, } if base_url := model.base_url: try: parsed = urlparse(base_url) except Exception: # pragma: no cover pass else: if parsed.hostname: # pragma: no branch attributes['server.address'] = parsed.hostname if parsed.port: # pragma: no branch attributes['server.port'] = parsed.port return attributes @staticmethod def event_to_dict(event: Event) -> dict[str, Any]: if not event.body: body = {} # pragma: no cover elif isinstance(event.body, Mapping): body = event.body # type: ignore else: body = {'body': event.body} return {**body, **(event.attributes or {})} @staticmethod def serialize_any(value: Any) -> str: try: return ANY_ADAPTER.dump_python(value, mode='json') except Exception: try: return str(value) except Exception as e: return f'Unable to serialize: {e}' ``` #### settings ```python settings: InstrumentationSettings = ( options or InstrumentationSettings() ) ``` Configuration for instrumenting requests. # pydantic_ai.models.mcp_sampling ### MCPSamplingModelSettings Bases: `ModelSettings` Settings used for an MCP Sampling model request. Source code in `pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py` ```python class MCPSamplingModelSettings(ModelSettings, total=False): """Settings used for an MCP Sampling model request.""" # ALL FIELDS MUST BE `mcp_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. 
mcp_model_preferences: ModelPreferences """Model preferences to use for MCP Sampling.""" ``` #### mcp_model_preferences ```python mcp_model_preferences: ModelPreferences ``` Model preferences to use for MCP Sampling. ### MCPSamplingModel Bases: `Model` A model that uses MCP Sampling. [MCP Sampling](https://modelcontextprotocol.io/docs/concepts/sampling) allows an MCP server to make requests to a model by calling back to the MCP client that connected to it. Source code in `pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py` ```python @dataclass class MCPSamplingModel(Model): """A model that uses MCP Sampling. [MCP Sampling](https://modelcontextprotocol.io/docs/concepts/sampling) allows an MCP server to make requests to a model by calling back to the MCP client that connected to it. """ session: ServerSession """The MCP server session to use for sampling.""" default_max_tokens: int = 16_384 """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens]. Max tokens is a required parameter for MCP Sampling, but optional on [`ModelSettings`][pydantic_ai.settings.ModelSettings], so this value is used as fallback. """ async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages) model_settings = cast(MCPSamplingModelSettings, model_settings or {}) result = await self.session.create_message( sampling_messages, max_tokens=model_settings.get('max_tokens', self.default_max_tokens), system_prompt=system_prompt, temperature=model_settings.get('temperature'), model_preferences=model_settings.get('mcp_model_preferences'), stop_sequences=model_settings.get('stop_sequences'), ) if result.role == 'assistant': return ModelResponse( parts=[_mcp.map_from_sampling_content(result.content)], usage=usage.Usage(requests=1), model_name=result.model, ) else: raise exceptions.UnexpectedModelBehavior( f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.' ) @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: raise NotImplementedError('MCP Sampling does not support streaming') yield @property def model_name(self) -> str: """The model name. Since the model name isn't known until the request is made, this property always returns `'mcp-sampling'`. """ return 'mcp-sampling' @property def system(self) -> str: """The system / model provider, returns `'MCP'`.""" return 'MCP' ``` #### session ```python session: ServerSession ``` The MCP server session to use for sampling. #### default_max_tokens ```python default_max_tokens: int = 16384 ``` Default max tokens to use if not set in ModelSettings. Max tokens is a required parameter for MCP Sampling, but optional on ModelSettings, so this value is used as fallback. #### model_name ```python model_name: str ``` The model name. Since the model name isn't known until the request is made, this property always returns `'mcp-sampling'`. #### system ```python system: str ``` The system / model provider, returns `'MCP'`. # `pydantic_ai.models.mistral` ## Setup For details on how to set up authentication with this model, see [model configuration for Mistral](../../../models/mistral/). 
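Once authentication is configured, wiring Mistral into an agent looks roughly like this minimal sketch (the model name and prompt are illustrative, and it assumes the `mistral` optional dependency is installed and a `MISTRAL_API_KEY` is set in the environment):

```python
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

model = MistralModel('mistral-large-latest')
agent = Agent(model, output_type=bool, system_prompt='Answer strictly true or false.')

result = agent.run_sync('Is Rome the capital of Italy?')
print(result.output)
#> True
```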
### LatestMistralModelNames

```python
LatestMistralModelNames = Literal[
    "mistral-large-latest",
    "mistral-small-latest",
    "codestral-latest",
    "mistral-moderation-latest",
]
```

Latest Mistral models.

### MistralModelName

```python
MistralModelName = Union[str, LatestMistralModelNames]
```

Possible Mistral model names.

Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but allow any name in the type hints. See [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_overview/) for a full list.

### MistralModelSettings

Bases: `ModelSettings`

Settings used for a Mistral model request.

Source code in `pydantic_ai_slim/pydantic_ai/models/mistral.py`

```python
class MistralModelSettings(ModelSettings, total=False):
    """Settings used for a Mistral model request."""
```

### MistralModel

Bases: `Model`

A model that uses Mistral.

Internally, this uses the [Mistral Python client](https://github.com/mistralai/client-python) to interact with the API.

[API Documentation](https://docs.mistral.ai/)

Source code in `pydantic_ai_slim/pydantic_ai/models/mistral.py`

````python
@dataclass(init=False)
class MistralModel(Model):
    """A model that uses Mistral.

    Internally, this uses the [Mistral Python client](https://github.com/mistralai/client-python) to interact with the API.

    [API Documentation](https://docs.mistral.ai/)
    """

    client: Mistral = field(repr=False)
    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

    _model_name: MistralModelName = field(repr=False)
    _system: str = field(default='mistral_ai', repr=False)

    def __init__(
        self,
        model_name: MistralModelName,
        *,
        provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
        profile: ModelProfileSpec | None = None,
        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
    ):
        """Initialize a Mistral model.

        Args:
            model_name: The name of the model to use.
            provider: The provider to use for authentication and API access. Can be either the string
                'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
                created using the other parameters.
            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
            json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
""" self._model_name = model_name self.json_mode_schema_prompt = json_mode_schema_prompt if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile @property def base_url(self) -> str: return self.client.sdk_configuration.get_server_details()[0] async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Make a non-streaming request to the model from Pydantic AI call.""" check_allow_model_requests() response = await self._completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) model_response.usage.requests = 1 return model_response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() response = await self._stream_completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) async with response: yield await self._process_streamed_response(model_request_parameters.output_tools, response) @property def model_name(self) -> MistralModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system async def _completions_create( self, messages: list[ModelMessage], model_settings: MistralModelSettings, model_request_parameters: ModelRequestParameters, ) -> MistralChatCompletionResponse: """Make a non-streaming request to the model.""" try: response = await self.client.chat.complete_async( model=str(self._model_name), messages=self._map_messages(messages), n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, tool_choice=self._get_tool_choice(model_request_parameters), stream=False, max_tokens=model_settings.get('max_tokens', UNSET), temperature=model_settings.get('temperature', UNSET), top_p=model_settings.get('top_p', 1), timeout_ms=self._get_timeout_ms(model_settings.get('timeout')), random_seed=model_settings.get('seed', UNSET), stop=model_settings.get('stop_sequences', None), http_headers={'User-Agent': get_user_agent()}, ) except SDKError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover assert response, 'A unexpected empty response from Mistral.' 
return response async def _stream_completions_create( self, messages: list[ModelMessage], model_settings: MistralModelSettings, model_request_parameters: ModelRequestParameters, ) -> MistralEventStreamAsync[MistralCompletionEvent]: """Create a streaming completion request to the Mistral model.""" response: MistralEventStreamAsync[MistralCompletionEvent] | None mistral_messages = self._map_messages(messages) if ( model_request_parameters.output_tools and model_request_parameters.function_tools or model_request_parameters.function_tools ): # Function Calling response = await self.client.chat.stream_async( model=str(self._model_name), messages=mistral_messages, n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, tool_choice=self._get_tool_choice(model_request_parameters), temperature=model_settings.get('temperature', UNSET), top_p=model_settings.get('top_p', 1), max_tokens=model_settings.get('max_tokens', UNSET), timeout_ms=self._get_timeout_ms(model_settings.get('timeout')), presence_penalty=model_settings.get('presence_penalty'), frequency_penalty=model_settings.get('frequency_penalty'), stop=model_settings.get('stop_sequences', None), http_headers={'User-Agent': get_user_agent()}, ) elif model_request_parameters.output_tools: # TODO: Port to native "manual JSON" mode # Json Mode parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools] user_output_format_message = self._generate_user_output_format(parameters_json_schemas) mistral_messages.append(user_output_format_message) response = await self.client.chat.stream_async( model=str(self._model_name), messages=mistral_messages, response_format={ 'type': 'json_object' }, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9 stream=True, http_headers={'User-Agent': get_user_agent()}, ) else: # Stream Mode response = await self.client.chat.stream_async( model=str(self._model_name), messages=mistral_messages, stream=True, http_headers={'User-Agent': get_user_agent()}, ) assert response, 'A unexpected empty response from Mistral.' return response def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None: """Get tool choice for the model. - "auto": Default mode. Model decides if it uses the tool or not. - "any": Select any tool. - "none": Prevents tool use. - "required": Forces tool use. """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None elif not model_request_parameters.allow_text_output: return 'required' else: return 'auto' def _map_function_and_output_tools_definition( self, model_request_parameters: ModelRequestParameters ) -> list[MistralTool] | None: """Map function and output tools to MistralTool format. Returns None if both function_tools and output_tools are empty. 
""" all_tools: list[ToolDefinition] = ( model_request_parameters.function_tools + model_request_parameters.output_tools ) tools = [ MistralTool( function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description) ) for r in all_tools ] return tools if tools else None def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" assert response.choices, 'Unexpected empty response choice.' if response.created: timestamp = number_to_datetime(response.created) else: timestamp = _now_utc() choice = response.choices[0] content = choice.message.content tool_calls = choice.message.tool_calls parts: list[ModelResponsePart] = [] if text := _map_content(content): parts.extend(split_content_into_text_and_thinking(text)) if isinstance(tool_calls, list): for tool_call in tool_calls: tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call) parts.append(tool) return ModelResponse( parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id ) async def _process_streamed_response( self, output_tools: list[ToolDefinition], response: MistralEventStreamAsync[MistralCompletionEvent], ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior( # pragma: no cover 'Streamed response ended without content or tool calls' ) if first_chunk.data.created: timestamp = number_to_datetime(first_chunk.data.created) else: timestamp = _now_utc() return MistralStreamedResponse( _response=peekable_response, _model_name=self._model_name, _timestamp=timestamp, _output_tools={c.name: c for c in output_tools}, ) @staticmethod def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart: """Maps a MistralToolCall to a ToolCall.""" tool_call_id = tool_call.id or _generate_tool_call_id() func_call = tool_call.function return ToolCallPart(func_call.name, func_call.arguments, tool_call_id) @staticmethod def _map_tool_call(t: ToolCallPart) -> MistralToolCall: """Maps a pydantic-ai ToolCall to a MistralToolCall.""" return MistralToolCall( id=_utils.guard_tool_call_id(t=t), type='function', function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}), ) def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage: """Get a message with an example of the expected output format.""" examples: list[dict[str, Any]] = [] for schema in schemas: typed_dict_definition: dict[str, Any] = {} for key, value in schema.get('properties', {}).items(): typed_dict_definition[key] = self._get_python_type(value) examples.append(typed_dict_definition) example_schema = examples[0] if len(examples) == 1 else examples return MistralUserMessage(content=self.json_mode_schema_prompt.format(schema=example_schema)) @classmethod def _get_python_type(cls, value: dict[str, Any]) -> str: """Return a string representation of the Python type for a single JSON schema property. This function handles recursion for nested arrays/objects and `anyOf`. 
""" # 1) Handle anyOf first, because it's a different schema structure if any_of := value.get('anyOf'): # Simplistic approach: pick the first option in anyOf # (In reality, you'd possibly want to merge or union types) return f'Optional[{cls._get_python_type(any_of[0])}]' # 2) If we have a top-level "type" field value_type = value.get('type') if not value_type: # No explicit type; fallback return 'Any' # 3) Direct simple type mapping (string, integer, float, bool, None) if value_type in SIMPLE_JSON_TYPE_MAPPING and value_type != 'array' and value_type != 'object': return SIMPLE_JSON_TYPE_MAPPING[value_type] # 4) Array: Recursively get the item type if value_type == 'array': items = value.get('items', {}) return f'list[{cls._get_python_type(items)}]' # 5) Object: Check for additionalProperties if value_type == 'object': additional_properties = value.get('additionalProperties', {}) if isinstance(additional_properties, bool): return 'bool' # pragma: no cover additional_properties_type = additional_properties.get('type') if ( additional_properties_type in SIMPLE_JSON_TYPE_MAPPING and additional_properties_type != 'array' and additional_properties_type != 'object' ): # dict[str, bool/int/float/etc...] return f'dict[str, {SIMPLE_JSON_TYPE_MAPPING[additional_properties_type]}]' elif additional_properties_type == 'array': array_items = additional_properties.get('items', {}) return f'dict[str, list[{cls._get_python_type(array_items)}]]' elif additional_properties_type == 'object': # nested dictionary of unknown shape return 'dict[str, dict[str, Any]]' else: # If no additionalProperties type or something else, default to a generic dict return 'dict[str, Any]' # 6) Fallback return 'Any' @staticmethod def _get_timeout_ms(timeout: Timeout | float | None) -> int | None: """Convert a timeout to milliseconds.""" if timeout is None: return None if isinstance(timeout, float): # pragma: no cover return int(1000 * timeout) raise NotImplementedError('Timeout object is not yet supported for MistralModel.') def _map_user_message(self, message: ModelRequest) -> Iterable[MistralMessages]: for part in message.parts: if isinstance(part, SystemPromptPart): yield MistralSystemMessage(content=part.content) elif isinstance(part, UserPromptPart): yield self._map_user_prompt(part) elif isinstance(part, ToolReturnPart): yield MistralToolMessage( tool_call_id=part.tool_call_id, content=part.model_response_str(), ) elif isinstance(part, RetryPromptPart): if part.tool_name is None: yield MistralUserMessage(content=part.model_response()) # pragma: no cover else: yield MistralToolMessage( tool_call_id=part.tool_call_id, content=part.model_response(), ) else: assert_never(part) def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]: """Just maps a `pydantic_ai.Message` to a `MistralMessage`.""" mistral_messages: list[MistralMessages] = [] for message in messages: if isinstance(message, ModelRequest): mistral_messages.extend(self._map_user_message(message)) elif isinstance(message, ModelResponse): content_chunks: list[MistralContentChunk] = [] tool_calls: list[MistralToolCall] = [] for part in message.parts: if isinstance(part, TextPart): content_chunks.append(MistralTextChunk(text=part.content)) elif isinstance(part, ThinkingPart): # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this, # please open an issue. The below code is the code to send thinking to the provider. 
# content_chunks.append(MistralTextChunk(text=f'{part.content}')) pass elif isinstance(part, ToolCallPart): tool_calls.append(self._map_tool_call(part)) else: assert_never(part) mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)) else: assert_never(message) if instructions := self._get_instructions(messages): mistral_messages.insert(0, MistralSystemMessage(content=instructions)) # Post-process messages to insert fake assistant message after tool message if followed by user message # to work around `Unexpected role 'user' after role 'tool'` error. processed_messages: list[MistralMessages] = [] for i, current_message in enumerate(mistral_messages): processed_messages.append(current_message) if isinstance(current_message, MistralToolMessage) and i + 1 < len(mistral_messages): next_message = mistral_messages[i + 1] if isinstance(next_message, MistralUserMessage): # Insert a dummy assistant message processed_messages.append(MistralAssistantMessage(content=[MistralTextChunk(text='OK')])) return processed_messages def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage: content: str | list[MistralContentChunk] if isinstance(part.content, str): content = part.content else: content = [] for item in part.content: if isinstance(item, str): content.append(MistralTextChunk(text=item)) elif isinstance(item, ImageUrl): content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url))) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') if item.is_image: image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}') content.append(MistralImageURLChunk(image_url=image_url, type='image_url')) else: raise RuntimeError('Only image binary content is supported for Mistral.') elif isinstance(item, DocumentUrl): raise RuntimeError('DocumentUrl is not supported in Mistral.') # pragma: no cover elif isinstance(item, VideoUrl): raise RuntimeError('VideoUrl is not supported in Mistral.') else: # pragma: no cover raise RuntimeError(f'Unsupported content type: {type(item)}') return MistralUserMessage(content=content) ```` #### __init__ ````python __init__( model_name: MistralModelName, *, provider: ( Literal["mistral"] | Provider[Mistral] ) = "mistral", profile: ModelProfileSpec | None = None, json_mode_schema_prompt: str = "Answer in JSON Object, respect the format:\n```\n{schema}\n```\n" ) ```` Initialize a Mistral model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `MistralModelName` | The name of the model to use. | *required* | | `provider` | `Literal['mistral'] | Provider[Mistral]` | The provider to use for authentication and API access. Can be either the string 'mistral' or an instance of Provider[Mistral]. If not provided, a new provider will be created using the other parameters. | `'mistral'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | | `json_mode_schema_prompt` | `str` | The prompt to show when the model expects a JSON object as input. 
| ```` 'Answer in JSON Object, respect the format:\n```\n{schema}\n```\n' ```` | Source code in `pydantic_ai_slim/pydantic_ai/models/mistral.py` ````python def __init__( self, model_name: MistralModelName, *, provider: Literal['mistral'] | Provider[Mistral] = 'mistral', profile: ModelProfileSpec | None = None, json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""", ): """Initialize a Mistral model. Args: model_name: The name of the model to use. provider: The provider to use for authentication and API access. Can be either the string 'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be created using the other parameters. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input. """ self._model_name = model_name self.json_mode_schema_prompt = json_mode_schema_prompt if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile ```` #### request ```python request( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse ``` Make a non-streaming request to the model from Pydantic AI call. Source code in `pydantic_ai_slim/pydantic_ai/models/mistral.py` ```python async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Make a non-streaming request to the model from Pydantic AI call.""" check_allow_model_requests() response = await self._completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) model_response.usage.requests = 1 return model_response ``` #### request_stream ```python request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse] ``` Make a streaming request to the model from Pydantic AI call. Source code in `pydantic_ai_slim/pydantic_ai/models/mistral.py` ```python @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() response = await self._stream_completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) async with response: yield await self._process_streamed_response(model_request_parameters.output_tools, response) ``` #### model_name ```python model_name: MistralModelName ``` The model name. #### system ```python system: str ``` The system / model provider. ### MistralStreamedResponse Bases: `StreamedResponse` Implementation of `StreamedResponse` for Mistral models. 
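You typically don't construct `MistralStreamedResponse` yourself: it is created internally when an agent backed by a Mistral model is run with streaming. Here's a minimal sketch (assuming `MISTRAL_API_KEY` is set in the environment; `'mistral-large-latest'` is used purely as an example model name):

mistral_stream_example.py

```python
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

# Any valid MistralModelName can be used here; this one is just an example.
agent = Agent(MistralModel('mistral-large-latest'))


async def main():
    # Under the hood, the streamed chunks are handled by a MistralStreamedResponse.
    async with agent.run_stream('What is the capital of Spain?') as response:
        print(await response.get_output())
```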
Source code in `pydantic_ai_slim/pydantic_ai/models/mistral.py` ```python @dataclass class MistralStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Mistral models.""" _model_name: MistralModelName _response: AsyncIterable[MistralCompletionEvent] _timestamp: datetime _output_tools: dict[str, ToolDefinition] _delta_content: str = field(default='', init=False) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: chunk: MistralCompletionEvent async for chunk in self._response: self._usage += _map_usage(chunk.data) try: choice = chunk.data.choices[0] except IndexError: continue # Handle the text part of the response content = choice.delta.content text = _map_content(content) if text: # Attempt to produce an output tool call from the received text if self._output_tools: self._delta_content += text # TODO: Port to native "manual JSON" mode maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools) if maybe_tool_call_part: yield self._parts_manager.handle_tool_call_part( vendor_part_id='output', tool_name=maybe_tool_call_part.tool_name, args=maybe_tool_call_part.args_as_dict(), tool_call_id=maybe_tool_call_part.tool_call_id, ) else: yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) # Handle the explicit tool calls for index, dtc in enumerate(choice.delta.tool_calls or []): # It seems that mistral just sends full tool calls, so we just use them directly, rather than building yield self._parts_manager.handle_tool_call_part( vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id ) @property def model_name(self) -> MistralModelName: """Get the model name of the response.""" return self._model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._timestamp @staticmethod def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None: output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings') if output_json: for output_tool in output_tools.values(): # NOTE: Additional verification to prevent JSON validation to crash # Ensures required parameters in the JSON schema are respected, especially for stream-based return types. # Example with BaseModel and required fields. 
if not MistralStreamedResponse._validate_required_json_schema( output_json, output_tool.parameters_json_schema ): continue # The following part_id will be thrown away return ToolCallPart(tool_name=output_tool.name, args=output_json) @staticmethod def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool: """Validate that all required parameters in the JSON schema are present in the JSON dictionary.""" required_params = json_schema.get('required', []) properties = json_schema.get('properties', {}) for param in required_params: if param not in json_dict: return False param_schema = properties.get(param, {}) param_type = param_schema.get('type') param_items_type = param_schema.get('items', {}).get('type') if param_type == 'array' and param_items_type: if not isinstance(json_dict[param], list): return False for item in json_dict[param]: if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]): return False elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]): return False if isinstance(json_dict[param], dict) and 'properties' in param_schema: nested_schema = param_schema if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema): return False return True ``` #### model_name ```python model_name: MistralModelName ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. # `pydantic_ai.models.openai` ## Setup For details on how to set up authentication with this model, see [model configuration for OpenAI](../../../models/openai/). ### OpenAIModelName ```python OpenAIModelName = Union[str, ChatModel] ``` Possible OpenAI model names. Since OpenAI supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints. See [the OpenAI docs](https://platform.openai.com/docs/models) for a full list. Using this more broad type for the model name instead of the ChatModel definition allows this model to be used more easily with other model types (ie, Ollama, Deepseek). ### OpenAIModelSettings Bases: `ModelSettings` Settings used for an OpenAI model request. Source code in `pydantic_ai_slim/pydantic_ai/models/openai.py` ```python class OpenAIModelSettings(ModelSettings, total=False): """Settings used for an OpenAI model request.""" # ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. openai_reasoning_effort: ReasoningEffort """Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. """ openai_logprobs: bool """Include log probabilities in the response.""" openai_top_logprobs: int """Include log probabilities of the top n tokens in the response.""" openai_user: str """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse. See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details. """ openai_service_tier: Literal['auto', 'default', 'flex'] """The service tier to use for the model request. Currently supported values are `auto`, `default`, and `flex`. For more information, see [OpenAI's service tiers documentation](https://platform.openai.com/docs/api-reference/chat/object#chat/object-service_tier). 
""" openai_prediction: ChatCompletionPredictionContentParam """Enables [predictive outputs](https://platform.openai.com/docs/guides/predicted-outputs). This feature is currently only supported for some OpenAI models. """ ``` #### openai_reasoning_effort ```python openai_reasoning_effort: ReasoningEffort ``` Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. #### openai_logprobs ```python openai_logprobs: bool ``` Include log probabilities in the response. #### openai_top_logprobs ```python openai_top_logprobs: int ``` Include log probabilities of the top n tokens in the response. #### openai_user ```python openai_user: str ``` A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse. See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details. #### openai_service_tier ```python openai_service_tier: Literal['auto', 'default', 'flex'] ``` The service tier to use for the model request. Currently supported values are `auto`, `default`, and `flex`. For more information, see [OpenAI's service tiers documentation](https://platform.openai.com/docs/api-reference/chat/object#chat/object-service_tier). #### openai_prediction ```python openai_prediction: ChatCompletionPredictionContentParam ``` Enables [predictive outputs](https://platform.openai.com/docs/guides/predicted-outputs). This feature is currently only supported for some OpenAI models. ### OpenAIResponsesModelSettings Bases: `OpenAIModelSettings` Settings used for an OpenAI Responses model request. ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. Source code in `pydantic_ai_slim/pydantic_ai/models/openai.py` ```python class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False): """Settings used for an OpenAI Responses model request. ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. """ openai_builtin_tools: Sequence[FileSearchToolParam | WebSearchToolParam | ComputerToolParam] """The provided OpenAI built-in tools to use. See [OpenAI's built-in tools](https://platform.openai.com/docs/guides/tools?api-mode=responses) for more details. """ openai_reasoning_generate_summary: Literal['detailed', 'concise'] """Deprecated alias for `openai_reasoning_summary`.""" openai_reasoning_summary: Literal['detailed', 'concise'] """A summary of the reasoning performed by the model. This can be useful for debugging and understanding the model's reasoning process. One of `concise` or `detailed`. Check the [OpenAI Computer use documentation](https://platform.openai.com/docs/guides/tools-computer-use#1-send-a-request-to-the-model) for more details. """ openai_truncation: Literal['disabled', 'auto'] """The truncation strategy to use for the model response. It can be either: - `disabled` (default): If a model response will exceed the context window size for a model, the request will fail with a 400 error. - `auto`: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation. 
""" ``` #### openai_builtin_tools ```python openai_builtin_tools: Sequence[ FileSearchToolParam | WebSearchToolParam | ComputerToolParam ] ``` The provided OpenAI built-in tools to use. See [OpenAI's built-in tools](https://platform.openai.com/docs/guides/tools?api-mode=responses) for more details. #### openai_reasoning_generate_summary ```python openai_reasoning_generate_summary: Literal[ "detailed", "concise" ] ``` Deprecated alias for `openai_reasoning_summary`. #### openai_reasoning_summary ```python openai_reasoning_summary: Literal['detailed', 'concise'] ``` A summary of the reasoning performed by the model. This can be useful for debugging and understanding the model's reasoning process. One of `concise` or `detailed`. Check the [OpenAI Computer use documentation](https://platform.openai.com/docs/guides/tools-computer-use#1-send-a-request-to-the-model) for more details. #### openai_truncation ```python openai_truncation: Literal['disabled', 'auto'] ``` The truncation strategy to use for the model response. It can be either: - `disabled` (default): If a model response will exceed the context window size for a model, the request will fail with a 400 error. - `auto`: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation. ### OpenAIModel Bases: `Model` A model that uses the OpenAI API. Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/openai.py` ```python @dataclass(init=False) class OpenAIModel(Model): """A model that uses the OpenAI API. Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API. Apart from `__init__`, all methods are private or match those of the base class. """ client: AsyncOpenAI = field(repr=False) system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False) _model_name: OpenAIModelName = field(repr=False) _system: str = field(default='openai', repr=False) def __init__( self, model_name: OpenAIModelName, *, provider: Literal[ 'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github' ] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, system_prompt_role: OpenAISystemPromptRole | None = None, ): """Initialize an OpenAI model. Args: model_name: The name of the OpenAI model to use. List of model names available [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7) (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API). provider: The provider to use. Defaults to `'openai'`. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`. In the future, this may be inferred from the model name. 
""" self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile self.system_prompt_role = system_prompt_role @property def base_url(self) -> str: return str(self.client.base_url) async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() response = await self._completions_create( messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters ) model_response = self._process_response(response) model_response.usage.requests = 1 return model_response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters ) async with response: yield await self._process_streamed_response(response) @property def model_name(self) -> OpenAIModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system @overload async def _completions_create( self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> AsyncStream[ChatCompletionChunk]: ... @overload async def _completions_create( self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion: ... 
async def _completions_create( self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: tools = self._get_tools(model_request_parameters) if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None elif not model_request_parameters.allow_text_output: tool_choice = 'required' else: tool_choice = 'auto' openai_messages = await self._map_messages(messages) response_format: chat.completion_create_params.ResponseFormat | None = None if model_request_parameters.output_mode == 'native': output_object = model_request_parameters.output_object assert output_object is not None response_format = self._map_json_schema(output_object) elif ( model_request_parameters.output_mode == 'prompted' and self.profile.supports_json_object_output ): # pragma: no branch response_format = {'type': 'json_object'} sampling_settings = ( model_settings if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings else OpenAIModelSettings() ) try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) return await self.client.chat.completions.create( model=self._model_name, messages=openai_messages, parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stream=stream, stream_options={'include_usage': True} if stream else NOT_GIVEN, stop=model_settings.get('stop_sequences', NOT_GIVEN), max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), response_format=response_format or NOT_GIVEN, seed=model_settings.get('seed', NOT_GIVEN), reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN), user=model_settings.get('openai_user', NOT_GIVEN), service_tier=model_settings.get('openai_service_tier', NOT_GIVEN), prediction=model_settings.get('openai_prediction', NOT_GIVEN), temperature=sampling_settings.get('temperature', NOT_GIVEN), top_p=sampling_settings.get('top_p', NOT_GIVEN), presence_penalty=sampling_settings.get('presence_penalty', NOT_GIVEN), frequency_penalty=sampling_settings.get('frequency_penalty', NOT_GIVEN), logit_bias=sampling_settings.get('logit_bias', NOT_GIVEN), logprobs=sampling_settings.get('openai_logprobs', NOT_GIVEN), top_logprobs=sampling_settings.get('openai_top_logprobs', NOT_GIVEN), extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = number_to_datetime(response.created) choice = response.choices[0] items: list[ModelResponsePart] = [] # The `reasoning_content` is only present in DeepSeek models. 
if reasoning_content := getattr(choice.message, 'reasoning_content', None): items.append(ThinkingPart(content=reasoning_content)) vendor_details: dict[str, Any] | None = None # Add logprobs to vendor_details if available if choice.logprobs is not None and choice.logprobs.content: # Convert logprobs to a serializable format vendor_details = { 'logprobs': [ { 'token': lp.token, 'bytes': lp.bytes, 'logprob': lp.logprob, 'top_logprobs': [ {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs ], } for lp in choice.logprobs.content ], } if choice.message.content is not None: items.extend(split_content_into_text_and_thinking(choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id) part.tool_call_id = _guard_tool_call_id(part) items.append(part) return ModelResponse( items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_details=vendor_details, vendor_id=response.id, ) async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior( # pragma: no cover 'Streamed response ended without content or tool calls' ) return OpenAIStreamedResponse( _model_name=self._model_name, _response=peekable_response, _timestamp=number_to_datetime(first_chunk.created), ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] return tools async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" openai_messages: list[chat.ChatCompletionMessageParam] = [] for message in messages: if isinstance(message, ModelRequest): async for item in self._map_user_message(message): openai_messages.append(item) elif isinstance(message, ModelResponse): texts: list[str] = [] tool_calls: list[chat.ChatCompletionMessageToolCallParam] = [] for item in message.parts: if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ThinkingPart): # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this, # please open an issue. The below code is the code to send thinking to the provider. 
# texts.append(f'\n{item.content}\n') pass elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) else: assert_never(item) message_param = chat.ChatCompletionAssistantMessageParam(role='assistant') if texts: # Note: model responses from this model should only have one text item, so the following # shouldn't merge multiple texts into one unless you switch models between runs: message_param['content'] = '\n\n'.join(texts) if tool_calls: message_param['tool_calls'] = tool_calls openai_messages.append(message_param) else: assert_never(message) if instructions := self._get_instructions(messages): openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')) return openai_messages @staticmethod def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: return chat.ChatCompletionMessageToolCallParam( id=_guard_tool_call_id(t=t), type='function', function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, ) def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] 'type': 'json_schema', 'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True}, } if o.description: response_format_param['json_schema']['description'] = o.description if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: # pragma: no branch response_format_param['json_schema']['strict'] = o.strict return response_format_param def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam: tool_param: chat.ChatCompletionToolParam = { 'type': 'function', 'function': { 'name': f.name, 'description': f.description, 'parameters': f.parameters_json_schema, }, } if f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: tool_param['function']['strict'] = f.strict return tool_param async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]: for part in message.parts: if isinstance(part, SystemPromptPart): if self.system_prompt_role == 'developer': yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content) elif self.system_prompt_role == 'user': yield chat.ChatCompletionUserMessageParam(role='user', content=part.content) else: yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content) elif isinstance(part, UserPromptPart): yield await self._map_user_prompt(part) elif isinstance(part, ToolReturnPart): yield chat.ChatCompletionToolMessageParam( role='tool', tool_call_id=_guard_tool_call_id(t=part), content=part.model_response_str(), ) elif isinstance(part, RetryPromptPart): if part.tool_name is None: yield chat.ChatCompletionUserMessageParam( # pragma: no cover role='user', content=part.model_response() ) else: yield chat.ChatCompletionToolMessageParam( role='tool', tool_call_id=_guard_tool_call_id(t=part), content=part.model_response(), ) else: assert_never(part) @staticmethod async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam: content: str | list[ChatCompletionContentPartParam] if isinstance(part.content, str): content = part.content else: content = [] for item in part.content: if isinstance(item, str): content.append(ChatCompletionContentPartTextParam(text=item, type='text')) elif isinstance(item, ImageUrl): 
image_url = ImageURL(url=item.url) content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') if item.is_image: image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}') content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) elif item.is_audio: assert item.format in ('wav', 'mp3') audio = InputAudio(data=base64_encoded, format=item.format) content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio')) elif item.is_document: content.append( File( file=FileFile( file_data=f'data:{item.media_type};base64,{base64_encoded}', filename=f'filename.{item.format}', ), type='file', ) ) else: # pragma: no cover raise RuntimeError(f'Unsupported binary content type: {item.media_type}') elif isinstance(item, AudioUrl): downloaded_item = await download_item(item, data_format='base64', type_format='extension') assert downloaded_item['data_type'] in ( 'wav', 'mp3', ), f'Unsupported audio format: {downloaded_item["data_type"]}' audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type']) content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio')) elif isinstance(item, DocumentUrl): downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension') file = File( file=FileFile( file_data=downloaded_item['data'], filename=f'filename.{downloaded_item["data_type"]}' ), type='file', ) content.append(file) elif isinstance(item, VideoUrl): # pragma: no cover raise NotImplementedError('VideoUrl is not supported for OpenAI') else: assert_never(item) return chat.ChatCompletionUserMessageParam(role='user', content=content) ``` #### __init__ ```python __init__( model_name: OpenAIModelName, *, provider: ( Literal[ "openai", "deepseek", "azure", "openrouter", "grok", "fireworks", "together", "heroku", "github", ] | Provider[AsyncOpenAI] ) = "openai", profile: ModelProfileSpec | None = None, system_prompt_role: OpenAISystemPromptRole | None = None ) ``` Initialize an OpenAI model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `OpenAIModelName` | The name of the OpenAI model to use. List of model names available here (Unfortunately, despite being ask to do so, OpenAI do not provide .inv files for their API). | *required* | | `provider` | `Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'] | Provider[AsyncOpenAI]` | The provider to use. Defaults to 'openai'. | `'openai'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | | `system_prompt_role` | `OpenAISystemPromptRole | None` | The role to use for the system prompt message. If not provided, defaults to 'system'. In the future, this may be inferred from the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/openai.py` ```python def __init__( self, model_name: OpenAIModelName, *, provider: Literal[ 'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github' ] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, system_prompt_role: OpenAISystemPromptRole | None = None, ): """Initialize an OpenAI model. Args: model_name: The name of the OpenAI model to use. 
List of model names available [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7) (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API). provider: The provider to use. Defaults to `'openai'`. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`. In the future, this may be inferred from the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile self.system_prompt_role = system_prompt_role ``` #### model_name ```python model_name: OpenAIModelName ``` The model name. #### system ```python system: str ``` The system / model provider. ### OpenAIResponsesModel Bases: `Model` A model that uses the OpenAI Responses API. The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the new API for OpenAI models. The Responses API has built-in tools, that you can use instead of building your own: - [Web search](https://platform.openai.com/docs/guides/tools-web-search) - [File search](https://platform.openai.com/docs/guides/tools-file-search) - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use) Use the `openai_builtin_tools` setting to add these tools to your model. If you are interested in the differences between the Responses API and the Chat Completions API, see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions). Source code in `pydantic_ai_slim/pydantic_ai/models/openai.py` ```python @dataclass(init=False) class OpenAIResponsesModel(Model): """A model that uses the OpenAI Responses API. The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the new API for OpenAI models. The Responses API has built-in tools, that you can use instead of building your own: - [Web search](https://platform.openai.com/docs/guides/tools-web-search) - [File search](https://platform.openai.com/docs/guides/tools-file-search) - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use) Use the `openai_builtin_tools` setting to add these tools to your model. If you are interested in the differences between the Responses API and the Chat Completions API, see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions). """ client: AsyncOpenAI = field(repr=False) system_prompt_role: OpenAISystemPromptRole | None = field(default=None) _model_name: OpenAIModelName = field(repr=False) _system: str = field(default='openai', repr=False) def __init__( self, model_name: OpenAIModelName, *, provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, ): """Initialize an OpenAI Responses model. Args: model_name: The name of the OpenAI model to use. provider: The provider to use. Defaults to `'openai'`. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. 
""" self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile @property def model_name(self) -> OpenAIModelName: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system async def request( self, messages: list[ModelRequest | ModelResponse], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() response = await self._responses_create( messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) return self._process_response(response) @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._responses_create( messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) async with response: yield await self._process_streamed_response(response) def _process_response(self, response: responses.Response) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = number_to_datetime(response.created_at) items: list[ModelResponsePart] = [] for item in response.output: if item.type == 'reasoning': for summary in item.summary: # NOTE: We use the same id for all summaries because we can merge them on the round trip. # The providers don't force the signature to be unique. items.append(ThinkingPart(content=summary.text, id=item.id)) elif item.type == 'message': for content in item.content: if content.type == 'output_text': # pragma: no branch items.append(TextPart(content.text)) elif item.type == 'function_call': items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id)) return ModelResponse( items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id, timestamp=timestamp, ) async def _process_streamed_response( self, response: AsyncStream[responses.ResponseStreamEvent] ) -> OpenAIResponsesStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): # pragma: no cover raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') assert isinstance(first_chunk, responses.ResponseCreatedEvent) return OpenAIResponsesStreamedResponse( _model_name=self._model_name, _response=peekable_response, _timestamp=number_to_datetime(first_chunk.response.created_at), ) @overload async def _responses_create( self, messages: list[ModelRequest | ModelResponse], stream: Literal[False], model_settings: OpenAIResponsesModelSettings, model_request_parameters: ModelRequestParameters, ) -> responses.Response: ... @overload async def _responses_create( self, messages: list[ModelRequest | ModelResponse], stream: Literal[True], model_settings: OpenAIResponsesModelSettings, model_request_parameters: ModelRequestParameters, ) -> AsyncStream[responses.ResponseStreamEvent]: ... 
async def _responses_create( self, messages: list[ModelRequest | ModelResponse], stream: bool, model_settings: OpenAIResponsesModelSettings, model_request_parameters: ModelRequestParameters, ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]: tools = self._get_tools(model_request_parameters) tools = list(model_settings.get('openai_builtin_tools', [])) + tools if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None elif not model_request_parameters.allow_text_output: tool_choice = 'required' else: tool_choice = 'auto' instructions, openai_messages = await self._map_messages(messages) reasoning = self._get_reasoning(model_settings) text: responses.ResponseTextConfigParam | None = None if model_request_parameters.output_mode == 'native': output_object = model_request_parameters.output_object assert output_object is not None text = {'format': self._map_json_schema(output_object)} elif ( model_request_parameters.output_mode == 'prompted' and self.profile.supports_json_object_output ): # pragma: no branch text = {'format': {'type': 'json_object'}} # Without this trick, we'd hit this error: # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. # Apparently they're only checking input messages for "JSON", not instructions. assert isinstance(instructions, str) openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) instructions = NOT_GIVEN sampling_settings = ( model_settings if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings else OpenAIResponsesModelSettings() ) try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) return await self.client.responses.create( input=openai_messages, model=self._model_name, instructions=instructions, parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN), stream=stream, temperature=sampling_settings.get('temperature', NOT_GIVEN), top_p=sampling_settings.get('top_p', NOT_GIVEN), truncation=model_settings.get('openai_truncation', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), reasoning=reasoning, user=model_settings.get('openai_user', NOT_GIVEN), text=text or NOT_GIVEN, extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven: reasoning_effort = model_settings.get('openai_reasoning_effort', None) reasoning_summary = model_settings.get('openai_reasoning_summary', None) reasoning_generate_summary = model_settings.get('openai_reasoning_generate_summary', None) if reasoning_summary and reasoning_generate_summary: # pragma: no cover raise ValueError('`openai_reasoning_summary` and `openai_reasoning_generate_summary` cannot both be set.') if reasoning_generate_summary is not None: # pragma: no cover warnings.warn( '`openai_reasoning_generate_summary` is deprecated, use `openai_reasoning_summary` instead', DeprecationWarning, ) reasoning_summary = reasoning_generate_summary if reasoning_effort is None and reasoning_summary is None: return NOT_GIVEN return 
Reasoning(effort=reasoning_effort, summary=reasoning_summary) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]: tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] if model_request_parameters.output_tools: tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] return tools def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam: return { 'name': f.name, 'parameters': f.parameters_json_schema, 'type': 'function', 'description': f.description, 'strict': bool( f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition ), } async def _map_messages( self, messages: list[ModelMessage] ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]: """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`.""" openai_messages: list[responses.ResponseInputItemParam] = [] for message in messages: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, SystemPromptPart): openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content)) elif isinstance(part, UserPromptPart): openai_messages.append(await self._map_user_prompt(part)) elif isinstance(part, ToolReturnPart): openai_messages.append( FunctionCallOutput( type='function_call_output', call_id=_guard_tool_call_id(t=part), output=part.model_response_str(), ) ) elif isinstance(part, RetryPromptPart): # TODO(Marcelo): How do we test this conditional branch? if part.tool_name is None: # pragma: no cover openai_messages.append( Message(role='user', content=[{'type': 'input_text', 'text': part.model_response()}]) ) else: openai_messages.append( FunctionCallOutput( type='function_call_output', call_id=_guard_tool_call_id(t=part), output=part.model_response(), ) ) else: assert_never(part) elif isinstance(message, ModelResponse): # last_thinking_part_idx: int | None = None for item in message.parts: if isinstance(item, TextPart): openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content)) elif isinstance(item, ToolCallPart): openai_messages.append(self._map_tool_call(item)) elif isinstance(item, ThinkingPart): # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this, # please open an issue. The below code is the code to send thinking to the provider. 
# if last_thinking_part_idx is not None: # reasoning_item = cast(responses.ResponseReasoningItemParam, openai_messages[last_thinking_part_idx]) # fmt: skip # if item.id == reasoning_item['id']: # assert isinstance(reasoning_item['summary'], list) # reasoning_item['summary'].append(Summary(text=item.content, type='summary_text')) # continue # last_thinking_part_idx = len(openai_messages) # openai_messages.append( # responses.ResponseReasoningItemParam( # id=item.id or generate_tool_call_id(), # summary=[Summary(text=item.content, type='summary_text')], # type='reasoning', # ) # ) pass else: assert_never(item) else: assert_never(message) instructions = self._get_instructions(messages) or NOT_GIVEN return instructions, openai_messages @staticmethod def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam: return responses.ResponseFunctionToolCallParam( arguments=t.args_as_json_str(), call_id=_guard_tool_call_id(t=t), name=t.tool_name, type='function_call', ) def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: response_format_param: responses.ResponseFormatTextJSONSchemaConfigParam = { 'type': 'json_schema', 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, } if o.description: response_format_param['description'] = o.description if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: # pragma: no branch response_format_param['strict'] = o.strict return response_format_param @staticmethod async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam: content: str | list[responses.ResponseInputContentParam] if isinstance(part.content, str): content = part.content else: content = [] for item in part.content: if isinstance(item, str): content.append(responses.ResponseInputTextParam(text=item, type='input_text')) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') if item.is_image: content.append( responses.ResponseInputImageParam( image_url=f'data:{item.media_type};base64,{base64_encoded}', type='input_image', detail='auto', ) ) elif item.is_document: content.append( responses.ResponseInputFileParam( type='input_file', file_data=f'data:{item.media_type};base64,{base64_encoded}', # NOTE: Type wise it's not necessary to include the filename, but it's required by the # API itself. If we add empty string, the server sends a 500 error - which OpenAI needs # to fix. In any case, we add a placeholder name. 
filename=f'filename.{item.format}', ) ) elif item.is_audio: raise NotImplementedError('Audio as binary content is not supported for OpenAI Responses API.') else: # pragma: no cover raise RuntimeError(f'Unsupported binary content type: {item.media_type}') elif isinstance(item, ImageUrl): content.append( responses.ResponseInputImageParam(image_url=item.url, type='input_image', detail='auto') ) elif isinstance(item, AudioUrl): # pragma: no cover downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension') content.append( responses.ResponseInputFileParam( type='input_file', file_data=downloaded_item['data'], filename=f'filename.{downloaded_item["data_type"]}', ) ) elif isinstance(item, DocumentUrl): downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension') content.append( responses.ResponseInputFileParam( type='input_file', file_data=downloaded_item['data'], filename=f'filename.{downloaded_item["data_type"]}', ) ) elif isinstance(item, VideoUrl): # pragma: no cover raise NotImplementedError('VideoUrl is not supported for OpenAI.') else: assert_never(item) return responses.EasyInputMessageParam(role='user', content=content) ``` #### __init__ ```python __init__( model_name: OpenAIModelName, *, provider: ( Literal[ "openai", "deepseek", "azure", "openrouter", "grok", "fireworks", "together", ] | Provider[AsyncOpenAI] ) = "openai", profile: ModelProfileSpec | None = None ) ``` Initialize an OpenAI Responses model. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `model_name` | `OpenAIModelName` | The name of the OpenAI model to use. | *required* | | `provider` | `Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'] | Provider[AsyncOpenAI]` | The provider to use. Defaults to 'openai'. | `'openai'` | | `profile` | `ModelProfileSpec | None` | The model profile to use. Defaults to a profile picked by the provider based on the model name. | `None` | Source code in `pydantic_ai_slim/pydantic_ai/models/openai.py` ```python def __init__( self, model_name: OpenAIModelName, *, provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, ): """Initialize an OpenAI Responses model. Args: model_name: The name of the OpenAI model to use. provider: The provider to use. Defaults to `'openai'`. profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. """ self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self.client = provider.client self._profile = profile or provider.model_profile ``` #### model_name ```python model_name: OpenAIModelName ``` The model name. #### system ```python system: str ``` The system / model provider. # `pydantic_ai.models.test` Utility model for quickly testing apps built with PydanticAI. 
Here's a minimal example: test_model_usage.py ```py from pydantic_ai import Agent from pydantic_ai.models.test import TestModel my_agent = Agent('openai:gpt-4o', system_prompt='...') async def test_my_agent(): """Unit test for my_agent, to be run by pytest.""" m = TestModel() with my_agent.override(model=m): result = await my_agent.run('Testing my agent...') assert result.output == 'success (no tool calls)' assert m.last_model_request_parameters.function_tools == [] ``` See [Unit testing with `TestModel`](../../../testing/#unit-testing-with-testmodel) for detailed documentation. ### TestModel Bases: `Model` A model specifically for testing purposes. This will (by default) call all tools in the agent, then return a tool response if possible, otherwise a plain response. How useful this model is will vary significantly. Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those of the base class. Source code in `pydantic_ai_slim/pydantic_ai/models/test.py` ```python @dataclass class TestModel(Model): """A model specifically for testing purposes. This will (by default) call all tools in the agent, then return a tool response if possible, otherwise a plain response. How useful this model is will vary significantly. Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those of the base class. """ # NOTE: Avoid test discovery by pytest. __test__ = False call_tools: list[str] | Literal['all'] = 'all' """List of tools to call. If `'all'`, all tools will be called.""" custom_output_text: str | None = None """If set, this text is returned as the final output.""" custom_output_args: Any | None = None """If set, these args will be passed to the output tool.""" seed: int = 0 """Seed for generating random data.""" last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False) """The last ModelRequestParameters passed to the model in a request. The ModelRequestParameters contains information about the function and output tools available during request handling. This is set when a request is made, so will reflect the function tools from the last step of the last run. 
""" _model_name: str = field(default='test', repr=False) _system: str = field(default='test', repr=False) async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: self.last_model_request_parameters = model_request_parameters model_response = self._request(messages, model_settings, model_request_parameters) model_response.usage = _estimate_usage([*messages, model_response]) model_response.usage.requests = 1 return model_response @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: self.last_model_request_parameters = model_request_parameters model_response = self._request(messages, model_settings, model_request_parameters) yield TestStreamedResponse( _model_name=self._model_name, _structured_response=model_response, _messages=messages ) @property def model_name(self) -> str: """The model name.""" return self._model_name @property def system(self) -> str: """The system / model provider.""" return self._system def gen_tool_args(self, tool_def: ToolDefinition) -> Any: return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate() def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]: if self.call_tools == 'all': return [(r.name, r) for r in model_request_parameters.function_tools] else: function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools} tools_to_call = (function_tools_lookup[name] for name in self.call_tools) return [(r.name, r) for r in tools_to_call] def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput: if self.custom_output_text is not None: assert model_request_parameters.output_mode != 'tool', ( 'Plain response not allowed, but `custom_output_text` is set.' ) assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.' return _WrappedTextOutput(self.custom_output_text) elif self.custom_output_args is not None: assert model_request_parameters.output_tools is not None, ( 'No output tools provided, but `custom_output_args` is set.' ) output_tool = model_request_parameters.output_tools[0] if k := output_tool.outer_typed_dict_key: return _WrappedToolOutput({k: self.custom_output_args}) else: return _WrappedToolOutput(self.custom_output_args) elif model_request_parameters.allow_text_output: return _WrappedTextOutput(None) elif model_request_parameters.output_tools: return _WrappedToolOutput(None) else: return _WrappedTextOutput(None) def _request( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: tool_calls = self._get_tool_calls(model_request_parameters) output_wrapper = self._get_output(model_request_parameters) output_tools = model_request_parameters.output_tools # if there are tools, the first thing we want to do is call all of them if tool_calls and not any(isinstance(m, ModelResponse) for m in messages): return ModelResponse( parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls], model_name=self._model_name, ) if messages: # pragma: no branch last_message = messages[-1] assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.' 
# check if there are any retry prompts, if so retry them new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)} if new_retry_names: # Handle retries for both function tools and output tools # Check function tools first retry_parts: list[ModelResponsePart] = [ ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names ] # Check output tools if output_tools: retry_parts.extend( [ ToolCallPart( tool.name, output_wrapper.value if isinstance(output_wrapper, _WrappedToolOutput) and output_wrapper.value is not None else self.gen_tool_args(tool), ) for tool in output_tools if tool.name in new_retry_names ] ) return ModelResponse(parts=retry_parts, model_name=self._model_name) if isinstance(output_wrapper, _WrappedTextOutput): if (response_text := output_wrapper.value) is None: # build up details of tool responses output: dict[str, Any] = {} for message in messages: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, ToolReturnPart): output[part.tool_name] = part.content if output: return ModelResponse( parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name ) else: return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name) else: return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name) else: assert output_tools, 'No output tools provided' custom_output_args = output_wrapper.value output_tool = output_tools[self.seed % len(output_tools)] if custom_output_args is not None: return ModelResponse( parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name ) else: response_args = self.gen_tool_args(output_tool) return ModelResponse(parts=[ToolCallPart(output_tool.name, response_args)], model_name=self._model_name) ``` #### call_tools ```python call_tools: list[str] | Literal['all'] = 'all' ``` List of tools to call. If `'all'`, all tools will be called. #### custom_output_text ```python custom_output_text: str | None = None ``` If set, this text is returned as the final output. #### custom_output_args ```python custom_output_args: Any | None = None ``` If set, these args will be passed to the output tool. #### seed ```python seed: int = 0 ``` Seed for generating random data. #### last_model_request_parameters ```python last_model_request_parameters: ( ModelRequestParameters | None ) = field(default=None, init=False) ``` The last ModelRequestParameters passed to the model in a request. The ModelRequestParameters contains information about the function and output tools available during request handling. This is set when a request is made, so will reflect the function tools from the last step of the last run. #### model_name ```python model_name: str ``` The model name. #### system ```python system: str ``` The system / model provider. ### TestStreamedResponse Bases: `StreamedResponse` A structured response that streams test data. 
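`TestStreamedResponse` is created for you when you stream against a `TestModel`; you normally don't construct it directly. A minimal sketch of exercising it through `Agent.run_stream` under a `TestModel` override (the agent, prompt, and assertion are illustrative, and the exact chunking of text deltas is an implementation detail):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent('openai:gpt-4o', system_prompt='...')


async def test_streaming():
    """Hypothetical test: stream against TestModel instead of a real LLM."""
    with agent.override(model=TestModel(custom_output_text='hello world')):
        async with agent.run_stream('Say hello') as response:
            # TestStreamedResponse emits the text roughly word by word
            chunks = [chunk async for chunk in response.stream_text(delta=True)]
    assert ''.join(chunks) == 'hello world'
```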
Source code in `pydantic_ai_slim/pydantic_ai/models/test.py` ```python @dataclass class TestStreamedResponse(StreamedResponse): """A structured response that streams test data.""" _model_name: str _structured_response: ModelResponse _messages: InitVar[Iterable[ModelMessage]] _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) def __post_init__(self, _messages: Iterable[ModelMessage]): self._usage = _estimate_usage(_messages) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): if isinstance(part, TextPart): text = part.content *words, last_word = text.split(' ') words = [f'{word} ' for word in words] words.append(last_word) if len(words) == 1 and len(text) > 2: mid = len(text) // 2 words = [text[:mid], text[mid:]] self._usage += _get_string_usage('') yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='') for word in words: self._usage += _get_string_usage(word) yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) elif isinstance(part, ToolCallPart): yield self._parts_manager.handle_tool_call_part( vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id ) elif isinstance(part, ThinkingPart): # pragma: no cover # NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel. assert False, "This should be unreachable — we don't generate ThinkingPart on TestModel." else: assert_never(part) @property def model_name(self) -> str: """Get the model name of the response.""" return self._model_name @property def timestamp(self) -> datetime: """Get the timestamp of the response.""" return self._timestamp ``` #### model_name ```python model_name: str ``` Get the model name of the response. #### timestamp ```python timestamp: datetime ``` Get the timestamp of the response. # pydantic_ai.models.wrapper ### WrapperModel Bases: `Model` Model which wraps another model. Does nothing on its own, used as a base class. Source code in `pydantic_ai_slim/pydantic_ai/models/wrapper.py` ```python @dataclass(init=False) class WrapperModel(Model): """Model which wraps another model. Does nothing on its own, used as a base class. """ wrapped: Model """The underlying model being wrapped.""" def __init__(self, wrapped: Model | KnownModelName): self.wrapped = infer_model(wrapped) async def request(self, *args: Any, **kwargs: Any) -> ModelResponse: return await self.wrapped.request(*args, **kwargs) @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> AsyncIterator[StreamedResponse]: async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream: yield response_stream def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: return self.wrapped.customize_request_parameters(model_request_parameters) @property def model_name(self) -> str: return self.wrapped.model_name @property def system(self) -> str: return self.wrapped.system @cached_property def profile(self) -> ModelProfile: return self.wrapped.profile def __getattr__(self, item: str): return getattr(self.wrapped, item) # pragma: no cover ``` #### wrapped ```python wrapped: Model = infer_model(wrapped) ``` The underlying model being wrapped. # `pydantic_evals.dataset` Dataset management for pydantic evals. 
This module provides functionality for creating, loading, saving, and evaluating datasets of test cases. Each case must have inputs, and can optionally have a name, expected output, metadata, and case-specific evaluators. Datasets can be loaded from and saved to YAML or JSON files, and can be evaluated against a task function to produce an evaluation report. ### Case Bases: `Generic[InputsT, OutputT, MetadataT]` A single row of a Dataset. Each case represents a single test scenario with inputs to test. A case may optionally specify a name, expected outputs to compare against, and arbitrary metadata. Cases can also have their own specific evaluators which are run in addition to dataset-level evaluators. Example: ```python from pydantic_evals import Case case = Case( name='Simple addition', inputs={'a': 1, 'b': 2}, expected_output=3, metadata={'description': 'Tests basic addition'}, ) ``` Source code in `pydantic_evals/pydantic_evals/dataset.py` ````python @dataclass(init=False) class Case(Generic[InputsT, OutputT, MetadataT]): """A single row of a [`Dataset`][pydantic_evals.Dataset]. Each case represents a single test scenario with inputs to test. A case may optionally specify a name, expected outputs to compare against, and arbitrary metadata. Cases can also have their own specific evaluators which are run in addition to dataset-level evaluators. Example: ```python from pydantic_evals import Case case = Case( name='Simple addition', inputs={'a': 1, 'b': 2}, expected_output=3, metadata={'description': 'Tests basic addition'}, ) ``` """ name: str | None """Name of the case. This is used to identify the case in the report and can be used to filter cases.""" inputs: InputsT """Inputs to the task. This is the input to the task that will be evaluated.""" metadata: MetadataT | None = None """Metadata to be used in the evaluation. This can be used to provide additional information about the case to the evaluators. """ expected_output: OutputT | None = None """Expected output of the task. This is the expected output of the task that will be evaluated.""" evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = field(default_factory=list) """Evaluators to be used just on this case.""" def __init__( self, *, name: str | None = None, inputs: InputsT, metadata: MetadataT | None = None, expected_output: OutputT | None = None, evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (), ): """Initialize a new test case. Args: name: Optional name for the case. If not provided, a generic name will be assigned when added to a dataset. inputs: The inputs to the task being evaluated. metadata: Optional metadata for the case, which can be used by evaluators. expected_output: Optional expected output of the task, used for comparison in evaluators. evaluators: Tuple of evaluators specific to this case. These are in addition to any dataset-level evaluators. """ # Note: `evaluators` must be a tuple instead of Sequence due to misbehavior with pyright's generic parameter # inference if it has type `Sequence` self.name = name self.inputs = inputs self.metadata = metadata self.expected_output = expected_output self.evaluators = list(evaluators) ```` #### __init__ ```python __init__( *, name: str | None = None, inputs: InputsT, metadata: MetadataT | None = None, expected_output: OutputT | None = None, evaluators: tuple[ Evaluator[InputsT, OutputT, MetadataT], ... ] = () ) ``` Initialize a new test case. 
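The `evaluators` argument lets a case carry its own checks, which run in addition to any dataset-level evaluators. A minimal sketch, reusing the `ExactMatch` evaluator defined in the `Dataset` example below; the name, inputs, and expected output are illustrative:

```python
from dataclasses import dataclass

from pydantic_evals import Case
from pydantic_evals.evaluators import Evaluator, EvaluatorContext


@dataclass
class ExactMatch(Evaluator):
    def evaluate(self, ctx: EvaluatorContext) -> bool:
        return ctx.output == ctx.expected_output


case = Case(
    name='uppercase hello',
    inputs={'text': 'hello'},
    expected_output='HELLO',
    # runs only for this case, alongside any dataset-level evaluators
    evaluators=(ExactMatch(),),
)
```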
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `name` | `str | None` | Optional name for the case. If not provided, a generic name will be assigned when added to a dataset. | `None` | | `inputs` | `InputsT` | The inputs to the task being evaluated. | *required* | | `metadata` | `MetadataT | None` | Optional metadata for the case, which can be used by evaluators. | `None` | | `expected_output` | `OutputT | None` | Optional expected output of the task, used for comparison in evaluators. | `None` | | `evaluators` | `tuple[Evaluator[InputsT, OutputT, MetadataT], ...]` | Tuple of evaluators specific to this case. These are in addition to any dataset-level evaluators. | `()` | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def __init__( self, *, name: str | None = None, inputs: InputsT, metadata: MetadataT | None = None, expected_output: OutputT | None = None, evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (), ): """Initialize a new test case. Args: name: Optional name for the case. If not provided, a generic name will be assigned when added to a dataset. inputs: The inputs to the task being evaluated. metadata: Optional metadata for the case, which can be used by evaluators. expected_output: Optional expected output of the task, used for comparison in evaluators. evaluators: Tuple of evaluators specific to this case. These are in addition to any dataset-level evaluators. """ # Note: `evaluators` must be a tuple instead of Sequence due to misbehavior with pyright's generic parameter # inference if it has type `Sequence` self.name = name self.inputs = inputs self.metadata = metadata self.expected_output = expected_output self.evaluators = list(evaluators) ``` #### name ```python name: str | None = name ``` Name of the case. This is used to identify the case in the report and can be used to filter cases. #### inputs ```python inputs: InputsT = inputs ``` Inputs to the task. This is the input to the task that will be evaluated. #### metadata ```python metadata: MetadataT | None = metadata ``` Metadata to be used in the evaluation. This can be used to provide additional information about the case to the evaluators. #### expected_output ```python expected_output: OutputT | None = expected_output ``` Expected output of the task. This is the expected output of the task that will be evaluated. #### evaluators ```python evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = ( list(evaluators) ) ``` Evaluators to be used just on this case. ### Dataset Bases: `BaseModel`, `Generic[InputsT, OutputT, MetadataT]` A dataset of test cases. Datasets allow you to organize a collection of test cases and evaluate them against a task function. They can be loaded from and saved to YAML or JSON files, and can have dataset-level evaluators that apply to all cases. 
Example: ```python # Create a dataset with two test cases from dataclasses import dataclass from pydantic_evals import Case, Dataset from pydantic_evals.evaluators import Evaluator, EvaluatorContext @dataclass class ExactMatch(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> bool: return ctx.output == ctx.expected_output dataset = Dataset( cases=[ Case(name='test1', inputs={'text': 'Hello'}, expected_output='HELLO'), Case(name='test2', inputs={'text': 'World'}, expected_output='WORLD'), ], evaluators=[ExactMatch()], ) # Evaluate the dataset against a task function async def uppercase(inputs: dict) -> str: return inputs['text'].upper() async def main(): report = await dataset.evaluate(uppercase) report.print() ''' Evaluation Summary: uppercase ┏━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┓ ┃ Case ID ┃ Assertions ┃ Duration ┃ ┡━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━┩ │ test1 │ ✔ │ 10ms │ ├──────────┼────────────┼──────────┤ │ test2 │ ✔ │ 10ms │ ├──────────┼────────────┼──────────┤ │ Averages │ 100.0% ✔ │ 10ms │ └──────────┴────────────┴──────────┘ ''' ``` Source code in `pydantic_evals/pydantic_evals/dataset.py` ````python class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True): """A dataset of test [cases][pydantic_evals.Case]. Datasets allow you to organize a collection of test cases and evaluate them against a task function. They can be loaded from and saved to YAML or JSON files, and can have dataset-level evaluators that apply to all cases. Example: ```python # Create a dataset with two test cases from dataclasses import dataclass from pydantic_evals import Case, Dataset from pydantic_evals.evaluators import Evaluator, EvaluatorContext @dataclass class ExactMatch(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> bool: return ctx.output == ctx.expected_output dataset = Dataset( cases=[ Case(name='test1', inputs={'text': 'Hello'}, expected_output='HELLO'), Case(name='test2', inputs={'text': 'World'}, expected_output='WORLD'), ], evaluators=[ExactMatch()], ) # Evaluate the dataset against a task function async def uppercase(inputs: dict) -> str: return inputs['text'].upper() async def main(): report = await dataset.evaluate(uppercase) report.print() ''' Evaluation Summary: uppercase ┏━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┓ ┃ Case ID ┃ Assertions ┃ Duration ┃ ┡━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━┩ │ test1 │ ✔ │ 10ms │ ├──────────┼────────────┼──────────┤ │ test2 │ ✔ │ 10ms │ ├──────────┼────────────┼──────────┤ │ Averages │ 100.0% ✔ │ 10ms │ └──────────┴────────────┴──────────┘ ''' ``` """ cases: list[Case[InputsT, OutputT, MetadataT]] """List of test cases in the dataset.""" evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = [] """List of evaluators to be used on all cases in the dataset.""" def __init__( self, *, cases: Sequence[Case[InputsT, OutputT, MetadataT]], evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (), ): """Initialize a new dataset with test cases and optional evaluators. Args: cases: Sequence of test cases to include in the dataset. evaluators: Optional sequence of evaluators to apply to all cases in the dataset. 
""" case_names = set[str]() for case in cases: if case.name is None: continue if case.name in case_names: raise ValueError(f'Duplicate case name: {case.name!r}') case_names.add(case.name) super().__init__( cases=cases, evaluators=list(evaluators), ) async def evaluate( self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None, progress: bool = True, ) -> EvaluationReport: """Evaluates the test cases in the dataset using the given task. This method runs the task on each case in the dataset, applies evaluators, and collects results into a report. Cases are run concurrently, limited by `max_concurrency` if specified. Args: task: The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. name: The name of the task being evaluated, this is used to identify the task in the report. If omitted, the name of the task function will be used. max_concurrency: The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. progress: Whether to show a progress bar for the evaluation. Defaults to `True`. Returns: A report containing the results of the evaluation. """ name = name or get_unwrapped_function_name(task) total_cases = len(self.cases) progress_bar = Progress() if progress else None limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack() with _logfire.span('evaluate {name}', name=name) as eval_span, progress_bar or nullcontext(): task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str): async with limiter: result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators) if progress_bar and task_id is not None: # pragma: no branch progress_bar.update(task_id, advance=1) return result report = EvaluationReport( name=name, cases=await task_group_gather( [ lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}') for i, case in enumerate(self.cases, 1) ] ), ) # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel: eval_span.set_attribute('cases', report.cases) # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel eval_span.set_attribute('averages', report.averages()) return report def evaluate_sync( self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None, progress: bool = True, ) -> EvaluationReport: """Evaluates the test cases in the dataset using the given task. This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience. Args: task: The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. name: The name of the task being evaluated, this is used to identify the task in the report. If omitted, the name of the task function will be used. max_concurrency: The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. progress: Whether to show a progress bar for the evaluation. Defaults to True. Returns: A report containing the results of the evaluation. 
""" return get_event_loop().run_until_complete( self.evaluate(task, name=name, max_concurrency=max_concurrency, progress=progress) ) def add_case( self, *, name: str | None = None, inputs: InputsT, metadata: MetadataT | None = None, expected_output: OutputT | None = None, evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (), ) -> None: """Adds a case to the dataset. This is a convenience method for creating a [`Case`][pydantic_evals.Case] and adding it to the dataset. Args: name: Optional name for the case. If not provided, a generic name will be assigned. inputs: The inputs to the task being evaluated. metadata: Optional metadata for the case, which can be used by evaluators. expected_output: The expected output of the task, used for comparison in evaluators. evaluators: Tuple of evaluators specific to this case, in addition to dataset-level evaluators. """ if name in {case.name for case in self.cases}: raise ValueError(f'Duplicate case name: {name!r}') case = Case[InputsT, OutputT, MetadataT]( name=name, inputs=inputs, metadata=metadata, expected_output=expected_output, evaluators=evaluators, ) self.cases.append(case) def add_evaluator( self, evaluator: Evaluator[InputsT, OutputT, MetadataT], specific_case: str | None = None, ) -> None: """Adds an evaluator to the dataset or a specific case. Args: evaluator: The evaluator to add. specific_case: If provided, the evaluator will only be added to the case with this name. If None, the evaluator will be added to all cases in the dataset. Raises: ValueError: If `specific_case` is provided but no case with that name exists in the dataset. """ if specific_case is None: self.evaluators.append(evaluator) else: # If this is too slow, we could try to add a case lookup dict. # Note that if we do that, we'd need to make the cases list private to prevent modification. added = False for case in self.cases: if case.name == specific_case: case.evaluators.append(evaluator) added = True if not added: raise ValueError(f'Case {specific_case!r} not found in the dataset') @classmethod @functools.cache def _params(cls) -> tuple[type[InputsT], type[OutputT], type[MetadataT]]: """Get the type parameters for the Dataset class. Returns: A tuple of (InputsT, OutputT, MetadataT) types. """ for c in cls.__mro__: metadata = getattr(c, '__pydantic_generic_metadata__', {}) if len(args := (metadata.get('args', ()) or getattr(c, '__args__', ()))) == 3: # pragma: no branch return args else: # pragma: no cover warnings.warn( f'Could not determine the generic parameters for {cls}; using `Any` for each.' f' You should explicitly set the generic parameters via `Dataset[MyInputs, MyOutput, MyMetadata]`' f' when serializing or deserializing.', UserWarning, ) return Any, Any, Any # type: ignore @classmethod def from_file( cls, path: Path | str, fmt: Literal['yaml', 'json'] | None = None, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Load a dataset from a file. Args: path: Path to the file to load. fmt: Format of the file. If None, the format will be inferred from the file extension. Must be either 'yaml' or 'json'. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. Returns: A new Dataset instance loaded from the file. Raises: ValidationError: If the file cannot be parsed as a valid dataset. ValueError: If the format cannot be inferred from the file extension. 
""" path = Path(path) fmt = cls._infer_fmt(path, fmt) raw = Path(path).read_text() try: return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types) except ValidationError as e: # pragma: no cover raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e @classmethod def from_text( cls, contents: str, fmt: Literal['yaml', 'json'] = 'yaml', custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Load a dataset from a string. Args: contents: The string content to parse. fmt: Format of the content. Must be either 'yaml' or 'json'. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. Returns: A new Dataset instance parsed from the string. Raises: ValidationError: If the content cannot be parsed as a valid dataset. """ if fmt == 'yaml': loaded = yaml.safe_load(contents) return cls.from_dict(loaded, custom_evaluator_types) else: dataset_model_type = cls._serialization_type() dataset_model = dataset_model_type.model_validate_json(contents) return cls._from_dataset_model(dataset_model, custom_evaluator_types) @classmethod def from_dict( cls, data: dict[str, Any], custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Load a dataset from a dictionary. Args: data: Dictionary representation of the dataset. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. Returns: A new Dataset instance created from the dictionary. Raises: ValidationError: If the dictionary cannot be converted to a valid dataset. """ dataset_model_type = cls._serialization_type() dataset_model = dataset_model_type.model_validate(data) return cls._from_dataset_model(dataset_model, custom_evaluator_types) @classmethod def _from_dataset_model( cls, dataset_model: _DatasetModel[InputsT, OutputT, MetadataT], custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Create a Dataset from a _DatasetModel. Args: dataset_model: The _DatasetModel to convert. custom_evaluator_types: Custom evaluator classes to register for deserialization. Returns: A new Dataset instance created from the _DatasetModel. 
""" registry = _get_registry(custom_evaluator_types) cases: list[Case[InputsT, OutputT, MetadataT]] = [] errors: list[ValueError] = [] dataset_evaluators: list[Evaluator] = [] for spec in dataset_model.evaluators: try: dataset_evaluator = _load_evaluator_from_registry(registry, None, spec) except ValueError as e: errors.append(e) continue dataset_evaluators.append(dataset_evaluator) for row in dataset_model.cases: evaluators: list[Evaluator] = [] for spec in row.evaluators: try: evaluator = _load_evaluator_from_registry(registry, row.name, spec) except ValueError as e: errors.append(e) continue evaluators.append(evaluator) row = Case[InputsT, OutputT, MetadataT]( name=row.name, inputs=row.inputs, metadata=row.metadata, expected_output=row.expected_output, ) row.evaluators = evaluators cases.append(row) if errors: raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3]) result = cls(cases=cases) result.evaluators = dataset_evaluators return result def to_file( self, path: Path | str, fmt: Literal['yaml', 'json'] | None = None, schema_path: Path | str | None = DEFAULT_SCHEMA_PATH_TEMPLATE, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ): """Save the dataset to a file. Args: path: Path to save the dataset to. fmt: Format to use. If None, the format will be inferred from the file extension. Must be either 'yaml' or 'json'. schema_path: Path to save the JSON schema to. If None, no schema will be saved. Can be a string template with {stem} which will be replaced with the dataset filename stem. custom_evaluator_types: Custom evaluator classes to include in the schema. """ path = Path(path) fmt = self._infer_fmt(path, fmt) schema_ref: str | None = None if schema_path is not None: # pragma: no branch if isinstance(schema_path, str): # pragma: no branch schema_path = Path(schema_path.format(stem=path.stem)) if not schema_path.is_absolute(): schema_ref = str(schema_path) schema_path = path.parent / schema_path elif schema_path.is_relative_to(path): # pragma: no cover schema_ref = str(_get_relative_path_reference(schema_path, path)) else: # pragma: no cover schema_ref = str(schema_path) self._save_schema(schema_path, custom_evaluator_types) context: dict[str, Any] = {'use_short_form': True} if fmt == 'yaml': dumped_data = self.model_dump(mode='json', by_alias=True, exclude_defaults=True, context=context) content = yaml.dump(dumped_data, sort_keys=False) if schema_ref: # pragma: no branch yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}' content = f'{yaml_language_server_line}\n{content}' path.write_text(content) else: context['$schema'] = schema_ref json_data = self.model_dump_json(indent=2, by_alias=True, exclude_defaults=True, context=context) path.write_text(json_data + '\n') @classmethod def model_json_schema_with_evaluators( cls, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> dict[str, Any]: """Generate a JSON schema for this dataset type, including evaluator details. This is useful for generating a schema that can be used to validate YAML-format dataset files. Args: custom_evaluator_types: Custom evaluator classes to include in the schema. Returns: A dictionary representing the JSON schema. 
""" # Note: this function could maybe be simplified now that Evaluators are always dataclasses registry = _get_registry(custom_evaluator_types) evaluator_schema_types: list[Any] = [] for name, evaluator_class in registry.items(): type_hints = _typing_extra.get_function_type_hints(evaluator_class) type_hints.pop('return', None) required_type_hints: dict[str, Any] = {} for p in inspect.signature(evaluator_class).parameters.values(): type_hints.setdefault(p.name, Any) if p.default is not p.empty: type_hints[p.name] = NotRequired[type_hints[p.name]] else: required_type_hints[p.name] = type_hints[p.name] def _make_typed_dict(cls_name_prefix: str, fields: dict[str, Any]) -> Any: td = TypedDict(f'{cls_name_prefix}_{name}', fields) # pyright: ignore[reportArgumentType] config = ConfigDict(extra='forbid', arbitrary_types_allowed=True) # TODO: Replace with pydantic.with_config after pydantic 2.11 is released td.__pydantic_config__ = config # pyright: ignore[reportAttributeAccessIssue] return td # Shortest form: just the call name if len(type_hints) == 0 or not required_type_hints: evaluator_schema_types.append(Literal[name]) # Short form: can be called with only one parameter if len(type_hints) == 1: [type_hint_type] = type_hints.values() evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type})) elif len(required_type_hints) == 1: # pragma: no branch [type_hint_type] = required_type_hints.values() evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type})) # Long form: multiple parameters, possibly required if len(type_hints) > 1: params_td = _make_typed_dict('evaluator_params', type_hints) evaluator_schema_types.append(_make_typed_dict('evaluator', {name: params_td})) in_type, out_type, meta_type = cls._params() # Note: we shadow the `Case` and `Dataset` class names here to generate a clean JSON schema class Case(BaseModel, extra='forbid'): # pyright: ignore[reportUnusedClass] # this _is_ used below, but pyright doesn't seem to notice.. name: str | None = None inputs: in_type # pyright: ignore[reportInvalidTypeForm] metadata: meta_type | None = None # pyright: ignore[reportInvalidTypeForm,reportUnknownVariableType] expected_output: out_type | None = None # pyright: ignore[reportInvalidTypeForm,reportUnknownVariableType] if evaluator_schema_types: # pragma: no branch evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007 class Dataset(BaseModel, extra='forbid'): cases: list[Case] if evaluator_schema_types: # pragma: no branch evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007 json_schema = Dataset.model_json_schema() # See `_add_json_schema` below, since `$schema` is added to the JSON, it has to be supported in the JSON json_schema['properties']['$schema'] = {'type': 'string'} return json_schema @classmethod def _save_schema( cls, path: Path | str, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = () ): """Save the JSON schema for this dataset type to a file. Args: path: Path to save the schema to. custom_evaluator_types: Custom evaluator classes to include in the schema. 
""" path = Path(path) json_schema = cls.model_json_schema_with_evaluators(custom_evaluator_types) schema_content = to_json(json_schema, indent=2).decode() + '\n' if not path.exists() or path.read_text() != schema_content: # pragma: no branch path.write_text(schema_content) @classmethod @functools.cache def _serialization_type(cls) -> type[_DatasetModel[InputsT, OutputT, MetadataT]]: """Get the serialization type for this dataset class. Returns: A _DatasetModel type with the same generic parameters as this Dataset class. """ input_type, output_type, metadata_type = cls._params() return _DatasetModel[input_type, output_type, metadata_type] @classmethod def _infer_fmt(cls, path: Path, fmt: Literal['yaml', 'json'] | None) -> Literal['yaml', 'json']: """Infer the format to use for a file based on its extension. Args: path: The path to infer the format for. fmt: The explicitly provided format, if any. Returns: The inferred format ('yaml' or 'json'). Raises: ValueError: If the format cannot be inferred from the file extension. """ if fmt is not None: return fmt suffix = path.suffix.lower() if suffix in {'.yaml', '.yml'}: return 'yaml' elif suffix == '.json': return 'json' raise ValueError( f'Could not infer format for filename {path.name!r}. Use the `fmt` argument to specify the format.' ) @model_serializer(mode='wrap') def _add_json_schema(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo) -> dict[str, Any]: """Add the JSON schema path to the serialized output. See for context, that seems to be the nearest there is to a spec for this. """ context = cast(Union[dict[str, Any], None], info.context) if isinstance(context, dict) and (schema := context.get('$schema')): return {'$schema': schema} | nxt(self) else: return nxt(self) ```` #### cases ```python cases: list[Case[InputsT, OutputT, MetadataT]] ``` List of test cases in the dataset. #### evaluators ```python evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = ( [] ) ``` List of evaluators to be used on all cases in the dataset. #### __init__ ```python __init__( *, cases: Sequence[Case[InputsT, OutputT, MetadataT]], evaluators: Sequence[ Evaluator[InputsT, OutputT, MetadataT] ] = () ) ``` Initialize a new dataset with test cases and optional evaluators. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `cases` | `Sequence[Case[InputsT, OutputT, MetadataT]]` | Sequence of test cases to include in the dataset. | *required* | | `evaluators` | `Sequence[Evaluator[InputsT, OutputT, MetadataT]]` | Optional sequence of evaluators to apply to all cases in the dataset. | `()` | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def __init__( self, *, cases: Sequence[Case[InputsT, OutputT, MetadataT]], evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (), ): """Initialize a new dataset with test cases and optional evaluators. Args: cases: Sequence of test cases to include in the dataset. evaluators: Optional sequence of evaluators to apply to all cases in the dataset. """ case_names = set[str]() for case in cases: if case.name is None: continue if case.name in case_names: raise ValueError(f'Duplicate case name: {case.name!r}') case_names.add(case.name) super().__init__( cases=cases, evaluators=list(evaluators), ) ``` #### evaluate ```python evaluate( task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None, progress: bool = True, ) -> EvaluationReport ``` Evaluates the test cases in the dataset using the given task. 
This method runs the task on each case in the dataset, applies evaluators, and collects results into a report. Cases are run concurrently, limited by `max_concurrency` if specified. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `task` | `Callable[[InputsT], Awaitable[OutputT]]` | The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. | *required* | | `name` | `str | None` | The name of the task being evaluated, this is used to identify the task in the report. If omitted, the name of the task function will be used. | `None` | | `max_concurrency` | `int | None` | The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. | `None` | | `progress` | `bool` | Whether to show a progress bar for the evaluation. Defaults to True. | `True` | Returns: | Type | Description | | --- | --- | | `EvaluationReport` | A report containing the results of the evaluation. | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python async def evaluate( self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None, progress: bool = True, ) -> EvaluationReport: """Evaluates the test cases in the dataset using the given task. This method runs the task on each case in the dataset, applies evaluators, and collects results into a report. Cases are run concurrently, limited by `max_concurrency` if specified. Args: task: The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. name: The name of the task being evaluated, this is used to identify the task in the report. If omitted, the name of the task function will be used. max_concurrency: The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. progress: Whether to show a progress bar for the evaluation. Defaults to `True`. Returns: A report containing the results of the evaluation. """ name = name or get_unwrapped_function_name(task) total_cases = len(self.cases) progress_bar = Progress() if progress else None limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack() with _logfire.span('evaluate {name}', name=name) as eval_span, progress_bar or nullcontext(): task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str): async with limiter: result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators) if progress_bar and task_id is not None: # pragma: no branch progress_bar.update(task_id, advance=1) return result report = EvaluationReport( name=name, cases=await task_group_gather( [ lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}') for i, case in enumerate(self.cases, 1) ] ), ) # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel: eval_span.set_attribute('cases', report.cases) # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel eval_span.set_attribute('averages', report.averages()) return report ``` #### evaluate_sync ```python evaluate_sync( task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None, progress: bool = True, ) -> EvaluationReport ``` Evaluates the test cases in the dataset using the given task. 
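A minimal sketch of running the evaluation from synchronous code, reusing the `uppercase` task and `dataset` from the `Dataset` example above; the concurrency limit is illustrative:

```python
# `dataset` and `uppercase` are the objects from the Dataset example above
report = dataset.evaluate_sync(uppercase, max_concurrency=4)
report.print()
```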
This is a synchronous wrapper around evaluate provided for convenience. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `task` | `Callable[[InputsT], Awaitable[OutputT]]` | The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. | *required* | | `name` | `str | None` | The name of the task being evaluated, this is used to identify the task in the report. If omitted, the name of the task function will be used. | `None` | | `max_concurrency` | `int | None` | The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. | `None` | | `progress` | `bool` | Whether to show a progress bar for the evaluation. Defaults to True. | `True` | Returns: | Type | Description | | --- | --- | | `EvaluationReport` | A report containing the results of the evaluation. | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def evaluate_sync( self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None, progress: bool = True, ) -> EvaluationReport: """Evaluates the test cases in the dataset using the given task. This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience. Args: task: The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. name: The name of the task being evaluated, this is used to identify the task in the report. If omitted, the name of the task function will be used. max_concurrency: The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. progress: Whether to show a progress bar for the evaluation. Defaults to True. Returns: A report containing the results of the evaluation. """ return get_event_loop().run_until_complete( self.evaluate(task, name=name, max_concurrency=max_concurrency, progress=progress) ) ``` #### add_case ```python add_case( *, name: str | None = None, inputs: InputsT, metadata: MetadataT | None = None, expected_output: OutputT | None = None, evaluators: tuple[ Evaluator[InputsT, OutputT, MetadataT], ... ] = () ) -> None ``` Adds a case to the dataset. This is a convenience method for creating a Case and adding it to the dataset. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `name` | `str | None` | Optional name for the case. If not provided, a generic name will be assigned. | `None` | | `inputs` | `InputsT` | The inputs to the task being evaluated. | *required* | | `metadata` | `MetadataT | None` | Optional metadata for the case, which can be used by evaluators. | `None` | | `expected_output` | `OutputT | None` | The expected output of the task, used for comparison in evaluators. | `None` | | `evaluators` | `tuple[Evaluator[InputsT, OutputT, MetadataT], ...]` | Tuple of evaluators specific to this case, in addition to dataset-level evaluators. | `()` | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def add_case( self, *, name: str | None = None, inputs: InputsT, metadata: MetadataT | None = None, expected_output: OutputT | None = None, evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (), ) -> None: """Adds a case to the dataset. This is a convenience method for creating a [`Case`][pydantic_evals.Case] and adding it to the dataset. Args: name: Optional name for the case. If not provided, a generic name will be assigned. 
inputs: The inputs to the task being evaluated. metadata: Optional metadata for the case, which can be used by evaluators. expected_output: The expected output of the task, used for comparison in evaluators. evaluators: Tuple of evaluators specific to this case, in addition to dataset-level evaluators. """ if name in {case.name for case in self.cases}: raise ValueError(f'Duplicate case name: {name!r}') case = Case[InputsT, OutputT, MetadataT]( name=name, inputs=inputs, metadata=metadata, expected_output=expected_output, evaluators=evaluators, ) self.cases.append(case) ``` #### add_evaluator ```python add_evaluator( evaluator: Evaluator[InputsT, OutputT, MetadataT], specific_case: str | None = None, ) -> None ``` Adds an evaluator to the dataset or a specific case. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `evaluator` | `Evaluator[InputsT, OutputT, MetadataT]` | The evaluator to add. | *required* | | `specific_case` | `str | None` | If provided, the evaluator will only be added to the case with this name. If None, the evaluator will be added to all cases in the dataset. | `None` | Raises: | Type | Description | | --- | --- | | `ValueError` | If specific_case is provided but no case with that name exists in the dataset. | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def add_evaluator( self, evaluator: Evaluator[InputsT, OutputT, MetadataT], specific_case: str | None = None, ) -> None: """Adds an evaluator to the dataset or a specific case. Args: evaluator: The evaluator to add. specific_case: If provided, the evaluator will only be added to the case with this name. If None, the evaluator will be added to all cases in the dataset. Raises: ValueError: If `specific_case` is provided but no case with that name exists in the dataset. """ if specific_case is None: self.evaluators.append(evaluator) else: # If this is too slow, we could try to add a case lookup dict. # Note that if we do that, we'd need to make the cases list private to prevent modification. added = False for case in self.cases: if case.name == specific_case: case.evaluators.append(evaluator) added = True if not added: raise ValueError(f'Case {specific_case!r} not found in the dataset') ``` #### from_file ```python from_file( path: Path | str, fmt: Literal["yaml", "json"] | None = None, custom_evaluator_types: Sequence[ type[Evaluator[InputsT, OutputT, MetadataT]] ] = (), ) -> Self ``` Load a dataset from a file. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `path` | `Path | str` | Path to the file to load. | *required* | | `fmt` | `Literal['yaml', 'json'] | None` | Format of the file. If None, the format will be inferred from the file extension. Must be either 'yaml' or 'json'. | `None` | | `custom_evaluator_types` | `Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]]` | Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. | `()` | Returns: | Type | Description | | --- | --- | | `Self` | A new Dataset instance loaded from the file. | Raises: | Type | Description | | --- | --- | | `ValidationError` | If the file cannot be parsed as a valid dataset. | | `ValueError` | If the format cannot be inferred from the file extension. 
| Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python @classmethod def from_file( cls, path: Path | str, fmt: Literal['yaml', 'json'] | None = None, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Load a dataset from a file. Args: path: Path to the file to load. fmt: Format of the file. If None, the format will be inferred from the file extension. Must be either 'yaml' or 'json'. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. Returns: A new Dataset instance loaded from the file. Raises: ValidationError: If the file cannot be parsed as a valid dataset. ValueError: If the format cannot be inferred from the file extension. """ path = Path(path) fmt = cls._infer_fmt(path, fmt) raw = Path(path).read_text() try: return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types) except ValidationError as e: # pragma: no cover raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e ``` #### from_text ```python from_text( contents: str, fmt: Literal["yaml", "json"] = "yaml", custom_evaluator_types: Sequence[ type[Evaluator[InputsT, OutputT, MetadataT]] ] = (), ) -> Self ``` Load a dataset from a string. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `contents` | `str` | The string content to parse. | *required* | | `fmt` | `Literal['yaml', 'json']` | Format of the content. Must be either 'yaml' or 'json'. | `'yaml'` | | `custom_evaluator_types` | `Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]]` | Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. | `()` | Returns: | Type | Description | | --- | --- | | `Self` | A new Dataset instance parsed from the string. | Raises: | Type | Description | | --- | --- | | `ValidationError` | If the content cannot be parsed as a valid dataset. | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python @classmethod def from_text( cls, contents: str, fmt: Literal['yaml', 'json'] = 'yaml', custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Load a dataset from a string. Args: contents: The string content to parse. fmt: Format of the content. Must be either 'yaml' or 'json'. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. Returns: A new Dataset instance parsed from the string. Raises: ValidationError: If the content cannot be parsed as a valid dataset. """ if fmt == 'yaml': loaded = yaml.safe_load(contents) return cls.from_dict(loaded, custom_evaluator_types) else: dataset_model_type = cls._serialization_type() dataset_model = dataset_model_type.model_validate_json(contents) return cls._from_dataset_model(dataset_model, custom_evaluator_types) ``` #### from_dict ```python from_dict( data: dict[str, Any], custom_evaluator_types: Sequence[ type[Evaluator[InputsT, OutputT, MetadataT]] ] = (), ) -> Self ``` Load a dataset from a dictionary. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `data` | `dict[str, Any]` | Dictionary representation of the dataset. | *required* | | `custom_evaluator_types` | `Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]]` | Custom evaluator classes to use when deserializing the dataset. 
These are additional evaluators beyond the default ones. | `()` | Returns: | Type | Description | | --- | --- | | `Self` | A new Dataset instance created from the dictionary. | Raises: | Type | Description | | --- | --- | | `ValidationError` | If the dictionary cannot be converted to a valid dataset. | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python @classmethod def from_dict( cls, data: dict[str, Any], custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> Self: """Load a dataset from a dictionary. Args: data: Dictionary representation of the dataset. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. Returns: A new Dataset instance created from the dictionary. Raises: ValidationError: If the dictionary cannot be converted to a valid dataset. """ dataset_model_type = cls._serialization_type() dataset_model = dataset_model_type.model_validate(data) return cls._from_dataset_model(dataset_model, custom_evaluator_types) ``` #### to_file ```python to_file( path: Path | str, fmt: Literal["yaml", "json"] | None = None, schema_path: ( Path | str | None ) = DEFAULT_SCHEMA_PATH_TEMPLATE, custom_evaluator_types: Sequence[ type[Evaluator[InputsT, OutputT, MetadataT]] ] = (), ) ``` Save the dataset to a file. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `path` | `Path | str` | Path to save the dataset to. | *required* | | `fmt` | `Literal['yaml', 'json'] | None` | Format to use. If None, the format will be inferred from the file extension. Must be either 'yaml' or 'json'. | `None` | | `schema_path` | `Path | str | None` | Path to save the JSON schema to. If None, no schema will be saved. Can be a string template with {stem} which will be replaced with the dataset filename stem. | `DEFAULT_SCHEMA_PATH_TEMPLATE` | | `custom_evaluator_types` | `Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]]` | Custom evaluator classes to include in the schema. | `()` | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def to_file( self, path: Path | str, fmt: Literal['yaml', 'json'] | None = None, schema_path: Path | str | None = DEFAULT_SCHEMA_PATH_TEMPLATE, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ): """Save the dataset to a file. Args: path: Path to save the dataset to. fmt: Format to use. If None, the format will be inferred from the file extension. Must be either 'yaml' or 'json'. schema_path: Path to save the JSON schema to. If None, no schema will be saved. Can be a string template with {stem} which will be replaced with the dataset filename stem. custom_evaluator_types: Custom evaluator classes to include in the schema. 
""" path = Path(path) fmt = self._infer_fmt(path, fmt) schema_ref: str | None = None if schema_path is not None: # pragma: no branch if isinstance(schema_path, str): # pragma: no branch schema_path = Path(schema_path.format(stem=path.stem)) if not schema_path.is_absolute(): schema_ref = str(schema_path) schema_path = path.parent / schema_path elif schema_path.is_relative_to(path): # pragma: no cover schema_ref = str(_get_relative_path_reference(schema_path, path)) else: # pragma: no cover schema_ref = str(schema_path) self._save_schema(schema_path, custom_evaluator_types) context: dict[str, Any] = {'use_short_form': True} if fmt == 'yaml': dumped_data = self.model_dump(mode='json', by_alias=True, exclude_defaults=True, context=context) content = yaml.dump(dumped_data, sort_keys=False) if schema_ref: # pragma: no branch yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}' content = f'{yaml_language_server_line}\n{content}' path.write_text(content) else: context['$schema'] = schema_ref json_data = self.model_dump_json(indent=2, by_alias=True, exclude_defaults=True, context=context) path.write_text(json_data + '\n') ``` #### model_json_schema_with_evaluators ```python model_json_schema_with_evaluators( custom_evaluator_types: Sequence[ type[Evaluator[InputsT, OutputT, MetadataT]] ] = (), ) -> dict[str, Any] ``` Generate a JSON schema for this dataset type, including evaluator details. This is useful for generating a schema that can be used to validate YAML-format dataset files. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `custom_evaluator_types` | `Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]]` | Custom evaluator classes to include in the schema. | `()` | Returns: | Type | Description | | --- | --- | | `dict[str, Any]` | A dictionary representing the JSON schema. | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python @classmethod def model_json_schema_with_evaluators( cls, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), ) -> dict[str, Any]: """Generate a JSON schema for this dataset type, including evaluator details. This is useful for generating a schema that can be used to validate YAML-format dataset files. Args: custom_evaluator_types: Custom evaluator classes to include in the schema. Returns: A dictionary representing the JSON schema. 
""" # Note: this function could maybe be simplified now that Evaluators are always dataclasses registry = _get_registry(custom_evaluator_types) evaluator_schema_types: list[Any] = [] for name, evaluator_class in registry.items(): type_hints = _typing_extra.get_function_type_hints(evaluator_class) type_hints.pop('return', None) required_type_hints: dict[str, Any] = {} for p in inspect.signature(evaluator_class).parameters.values(): type_hints.setdefault(p.name, Any) if p.default is not p.empty: type_hints[p.name] = NotRequired[type_hints[p.name]] else: required_type_hints[p.name] = type_hints[p.name] def _make_typed_dict(cls_name_prefix: str, fields: dict[str, Any]) -> Any: td = TypedDict(f'{cls_name_prefix}_{name}', fields) # pyright: ignore[reportArgumentType] config = ConfigDict(extra='forbid', arbitrary_types_allowed=True) # TODO: Replace with pydantic.with_config after pydantic 2.11 is released td.__pydantic_config__ = config # pyright: ignore[reportAttributeAccessIssue] return td # Shortest form: just the call name if len(type_hints) == 0 or not required_type_hints: evaluator_schema_types.append(Literal[name]) # Short form: can be called with only one parameter if len(type_hints) == 1: [type_hint_type] = type_hints.values() evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type})) elif len(required_type_hints) == 1: # pragma: no branch [type_hint_type] = required_type_hints.values() evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type})) # Long form: multiple parameters, possibly required if len(type_hints) > 1: params_td = _make_typed_dict('evaluator_params', type_hints) evaluator_schema_types.append(_make_typed_dict('evaluator', {name: params_td})) in_type, out_type, meta_type = cls._params() # Note: we shadow the `Case` and `Dataset` class names here to generate a clean JSON schema class Case(BaseModel, extra='forbid'): # pyright: ignore[reportUnusedClass] # this _is_ used below, but pyright doesn't seem to notice.. name: str | None = None inputs: in_type # pyright: ignore[reportInvalidTypeForm] metadata: meta_type | None = None # pyright: ignore[reportInvalidTypeForm,reportUnknownVariableType] expected_output: out_type | None = None # pyright: ignore[reportInvalidTypeForm,reportUnknownVariableType] if evaluator_schema_types: # pragma: no branch evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007 class Dataset(BaseModel, extra='forbid'): cases: list[Case] if evaluator_schema_types: # pragma: no branch evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007 json_schema = Dataset.model_json_schema() # See `_add_json_schema` below, since `$schema` is added to the JSON, it has to be supported in the JSON json_schema['properties']['$schema'] = {'type': 'string'} return json_schema ``` ### set_eval_attribute ```python set_eval_attribute(name: str, value: Any) -> None ``` Set an attribute on the current task run. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `name` | `str` | The name of the attribute. | *required* | | `value` | `Any` | The value of the attribute. | *required* | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def set_eval_attribute(name: str, value: Any) -> None: """Set an attribute on the current task run. Args: name: The name of the attribute. value: The value of the attribute. 
""" current_case = _CURRENT_TASK_RUN.get() if current_case is not None: # pragma: no branch current_case.record_attribute(name, value) ``` ### increment_eval_metric ```python increment_eval_metric( name: str, amount: int | float ) -> None ``` Increment a metric on the current task run. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `name` | `str` | The name of the metric. | *required* | | `amount` | `int | float` | The amount to increment by. | *required* | Source code in `pydantic_evals/pydantic_evals/dataset.py` ```python def increment_eval_metric(name: str, amount: int | float) -> None: """Increment a metric on the current task run. Args: name: The name of the metric. amount: The amount to increment by. """ current_case = _CURRENT_TASK_RUN.get() if current_case is not None: # pragma: no branch current_case.increment_metric(name, amount) ``` # `pydantic_evals.evaluators` ### Contains Bases: `Evaluator[object, object, object]` Check if the output contains the expected output. For strings, checks if expected_output is a substring of output. For lists/tuples, checks if expected_output is in output. For dicts, checks if all key-value pairs in expected_output are in output. Note: case_sensitive only applies when both the value and output are strings. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class Contains(Evaluator[object, object, object]): """Check if the output contains the expected output. For strings, checks if expected_output is a substring of output. For lists/tuples, checks if expected_output is in output. For dicts, checks if all key-value pairs in expected_output are in output. Note: case_sensitive only applies when both the value and output are strings. 
""" value: Any case_sensitive: bool = True as_strings: bool = False evaluation_name: str | None = field(default=None) def evaluate( self, ctx: EvaluatorContext[object, object, object], ) -> EvaluationReason: # Convert objects to strings if requested failure_reason: str | None = None as_strings = self.as_strings or (isinstance(self.value, str) and isinstance(ctx.output, str)) if as_strings: output_str = str(ctx.output) expected_str = str(self.value) if not self.case_sensitive: output_str = output_str.lower() expected_str = expected_str.lower() failure_reason: str | None = None if expected_str not in output_str: output_trunc = _truncated_repr(output_str, max_length=100) expected_trunc = _truncated_repr(expected_str, max_length=100) failure_reason = f'Output string {output_trunc} does not contain expected string {expected_trunc}' return EvaluationReason(value=failure_reason is None, reason=failure_reason) try: # Handle different collection types if isinstance(ctx.output, dict): if isinstance(self.value, dict): # Cast to Any to avoid type checking issues output_dict = cast(dict[Any, Any], ctx.output) # pyright: ignore[reportUnknownMemberType] expected_dict = cast(dict[Any, Any], self.value) # pyright: ignore[reportUnknownMemberType] for k in expected_dict: if k not in output_dict: k_trunc = _truncated_repr(k, max_length=30) failure_reason = f'Output dictionary does not contain expected key {k_trunc}' break elif output_dict[k] != expected_dict[k]: k_trunc = _truncated_repr(k, max_length=30) output_v_trunc = _truncated_repr(output_dict[k], max_length=100) expected_v_trunc = _truncated_repr(expected_dict[k], max_length=100) failure_reason = f'Output dictionary has different value for key {k_trunc}: {output_v_trunc} != {expected_v_trunc}' break else: if self.value not in ctx.output: # pyright: ignore[reportUnknownMemberType] output_trunc = _truncated_repr(ctx.output, max_length=200) # pyright: ignore[reportUnknownMemberType] failure_reason = f'Output {output_trunc} does not contain provided value as a key' elif self.value not in ctx.output: # pyright: ignore[reportOperatorIssue] # will be handled by except block output_trunc = _truncated_repr(ctx.output, max_length=200) failure_reason = f'Output {output_trunc} does not contain provided value' except (TypeError, ValueError) as e: failure_reason = f'Containment check failed: {e}' return EvaluationReason(value=failure_reason is None, reason=failure_reason) ``` ### Equals Bases: `Evaluator[object, object, object]` Check if the output exactly equals the provided value. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class Equals(Evaluator[object, object, object]): """Check if the output exactly equals the provided value.""" value: Any evaluation_name: str | None = field(default=None) def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool: return ctx.output == self.value ``` ### EqualsExpected Bases: `Evaluator[object, object, object]` Check if the output exactly equals the expected output. 
Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class EqualsExpected(Evaluator[object, object, object]): """Check if the output exactly equals the expected output.""" evaluation_name: str | None = field(default=None) def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool | dict[str, bool]: if ctx.expected_output is None: return {} # Only compare if expected output is provided return ctx.output == ctx.expected_output ``` ### HasMatchingSpan Bases: `Evaluator[object, object, object]` Check if the span tree contains a span that matches the specified query. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class HasMatchingSpan(Evaluator[object, object, object]): """Check if the span tree contains a span that matches the specified query.""" query: SpanQuery evaluation_name: str | None = field(default=None) def evaluate( self, ctx: EvaluatorContext[object, object, object], ) -> bool: return ctx.span_tree.any(self.query) ``` ### IsInstance Bases: `Evaluator[object, object, object]` Check if the output is an instance of a type with the given name. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class IsInstance(Evaluator[object, object, object]): """Check if the output is an instance of a type with the given name.""" type_name: str evaluation_name: str | None = field(default=None) def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluationReason: output = ctx.output for cls in type(output).__mro__: if cls.__name__ == self.type_name or cls.__qualname__ == self.type_name: return EvaluationReason(value=True) reason = f'output is of type {type(output).__name__}' if type(output).__qualname__ != type(output).__name__: reason += f' (qualname: {type(output).__qualname__})' return EvaluationReason(value=False, reason=reason) ``` ### LLMJudge Bases: `Evaluator[object, object, object]` Judge whether the output of a language model meets the criteria of a provided rubric. If you do not specify a model, it uses the default model for judging. This starts as 'openai:gpt-4o', but can be overridden by calling set_default_judge_model. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class LLMJudge(Evaluator[object, object, object]): """Judge whether the output of a language model meets the criteria of a provided rubric. If you do not specify a model, it uses the default model for judging. This starts as 'openai:gpt-4o', but can be overridden by calling [`set_default_judge_model`][pydantic_evals.evaluators.llm_as_a_judge.set_default_judge_model]. 
""" rubric: str model: models.Model | models.KnownModelName | None = None include_input: bool = False include_expected_output: bool = False model_settings: ModelSettings | None = None score: OutputConfig | Literal[False] = False assertion: OutputConfig | Literal[False] = field(default_factory=lambda: OutputConfig(include_reason=True)) async def evaluate( self, ctx: EvaluatorContext[object, object, object], ) -> EvaluatorOutput: if self.include_input: if self.include_expected_output: from .llm_as_a_judge import judge_input_output_expected grading_output = await judge_input_output_expected( ctx.inputs, ctx.output, ctx.expected_output, self.rubric, self.model, self.model_settings ) else: from .llm_as_a_judge import judge_input_output grading_output = await judge_input_output( ctx.inputs, ctx.output, self.rubric, self.model, self.model_settings ) else: if self.include_expected_output: from .llm_as_a_judge import judge_output_expected grading_output = await judge_output_expected( ctx.output, ctx.expected_output, self.rubric, self.model, self.model_settings ) else: from .llm_as_a_judge import judge_output grading_output = await judge_output(ctx.output, self.rubric, self.model, self.model_settings) output: dict[str, EvaluationScalar | EvaluationReason] = {} include_both = self.score is not False and self.assertion is not False evaluation_name = self.get_default_evaluation_name() if self.score is not False: default_name = f'{evaluation_name}_score' if include_both else evaluation_name _update_combined_output(output, grading_output.score, grading_output.reason, self.score, default_name) if self.assertion is not False: default_name = f'{evaluation_name}_pass' if include_both else evaluation_name _update_combined_output(output, grading_output.pass_, grading_output.reason, self.assertion, default_name) return output def build_serialization_arguments(self): result = super().build_serialization_arguments() # always serialize the model as a string when present; use its name if it's a KnownModelName if (model := result.get('model')) and isinstance(model, models.Model): # pragma: no branch result['model'] = f'{model.system}:{model.model_name}' # Note: this may lead to confusion if you try to serialize-then-deserialize with a custom model. # I expect that is rare enough to be worth not solving yet, but common enough that we probably will want to # solve it eventually. I'm imagining some kind of model registry, but don't want to work out the details yet. return result ``` ### MaxDuration Bases: `Evaluator[object, object, object]` Check if the execution time is under the specified maximum. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class MaxDuration(Evaluator[object, object, object]): """Check if the execution time is under the specified maximum.""" seconds: float | timedelta def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool: duration = timedelta(seconds=ctx.duration) seconds = self.seconds if not isinstance(seconds, timedelta): seconds = timedelta(seconds=seconds) return duration <= seconds ``` ### OutputConfig Bases: `TypedDict` Configuration for the score and assertion outputs of the LLMJudge evaluator. 
Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python class OutputConfig(TypedDict, total=False): """Configuration for the score and assertion outputs of the LLMJudge evaluator.""" evaluation_name: str include_reason: bool ``` ### Python Bases: `Evaluator[object, object, object]` The output of this evaluator is the result of evaluating the provided Python expression. ***WARNING***: this evaluator runs arbitrary Python code, so you should ***NEVER*** use it with untrusted inputs. Source code in `pydantic_evals/pydantic_evals/evaluators/common.py` ```python @dataclass(repr=False) class Python(Evaluator[object, object, object]): """The output of this evaluator is the result of evaluating the provided Python expression. ***WARNING***: this evaluator runs arbitrary Python code, so you should ***NEVER*** use it with untrusted inputs. """ expression: str evaluation_name: str | None = field(default=None) def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluatorOutput: # Evaluate the condition, exposing access to the evaluator context as `ctx`. return eval(self.expression, {'ctx': ctx}) ``` ### EvaluatorContext Bases: `Generic[InputsT, OutputT, MetadataT]` Context for evaluating a task execution. An instance of this class is the sole input to all Evaluators. It contains all the information needed to evaluate the task execution, including inputs, outputs, metadata, and telemetry data. Evaluators use this context to access the task inputs, actual output, expected output, and other information when evaluating the result of the task execution. Example: ```python from dataclasses import dataclass from pydantic_evals.evaluators import Evaluator, EvaluatorContext @dataclass class ExactMatch(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> bool: # Use the context to access task inputs, outputs, and expected outputs return ctx.output == ctx.expected_output ``` Source code in `pydantic_evals/pydantic_evals/evaluators/context.py` ````python @dataclass class EvaluatorContext(Generic[InputsT, OutputT, MetadataT]): """Context for evaluating a task execution. An instance of this class is the sole input to all Evaluators. It contains all the information needed to evaluate the task execution, including inputs, outputs, metadata, and telemetry data. Evaluators use this context to access the task inputs, actual output, expected output, and other information when evaluating the result of the task execution. Example: ```python from dataclasses import dataclass from pydantic_evals.evaluators import Evaluator, EvaluatorContext @dataclass class ExactMatch(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> bool: # Use the context to access task inputs, outputs, and expected outputs return ctx.output == ctx.expected_output ``` """ name: str | None """The name of the case.""" inputs: InputsT """The inputs provided to the task for this case.""" metadata: MetadataT | None """Metadata associated with the case, if provided. May be None if no metadata was specified.""" expected_output: OutputT | None """The expected output for the case, if provided. May be None if no expected output was specified.""" output: OutputT """The actual output produced by the task for this case.""" duration: float """The duration of the task run for this case.""" _span_tree: SpanTree | SpanTreeRecordingError = field(repr=False) """The span tree for the task run for this case. This will be `None` if `logfire.configure` has not been called. 
""" attributes: dict[str, Any] """Attributes associated with the task run for this case. These can be set by calling `pydantic_evals.dataset.set_eval_attribute` in any code executed during the evaluation task.""" metrics: dict[str, int | float] """Metrics associated with the task run for this case. These can be set by calling `pydantic_evals.dataset.increment_eval_metric` in any code executed during the evaluation task.""" @property def span_tree(self) -> SpanTree: """Get the `SpanTree` for this task execution. The span tree is a graph where each node corresponds to an OpenTelemetry span recorded during the task execution, including timing information and any custom spans created during execution. Returns: The span tree for the task execution. Raises: SpanTreeRecordingError: If spans were not captured during execution of the task, e.g. due to not having the necessary dependencies installed. """ if isinstance(self._span_tree, SpanTreeRecordingError): # In this case, there was a reason we couldn't record the SpanTree. We raise that now raise self._span_tree return self._span_tree ```` #### name ```python name: str | None ``` The name of the case. #### inputs ```python inputs: InputsT ``` The inputs provided to the task for this case. #### metadata ```python metadata: MetadataT | None ``` Metadata associated with the case, if provided. May be None if no metadata was specified. #### expected_output ```python expected_output: OutputT | None ``` The expected output for the case, if provided. May be None if no expected output was specified. #### output ```python output: OutputT ``` The actual output produced by the task for this case. #### duration ```python duration: float ``` The duration of the task run for this case. #### attributes ```python attributes: dict[str, Any] ``` Attributes associated with the task run for this case. These can be set by calling `pydantic_evals.dataset.set_eval_attribute` in any code executed during the evaluation task. #### metrics ```python metrics: dict[str, int | float] ``` Metrics associated with the task run for this case. These can be set by calling `pydantic_evals.dataset.increment_eval_metric` in any code executed during the evaluation task. #### span_tree ```python span_tree: SpanTree ``` Get the `SpanTree` for this task execution. The span tree is a graph where each node corresponds to an OpenTelemetry span recorded during the task execution, including timing information and any custom spans created during execution. Returns: | Type | Description | | --- | --- | | `SpanTree` | The span tree for the task execution. | Raises: | Type | Description | | --- | --- | | `SpanTreeRecordingError` | If spans were not captured during execution of the task, e.g. due to not having the necessary dependencies installed. | ### EvaluationReason The result of running an evaluator with an optional explanation. Contains a scalar value and an optional "reason" explaining the value. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `value` | `EvaluationScalar` | The scalar result of the evaluation (boolean, integer, float, or string). | *required* | | `reason` | `str | None` | An optional explanation of the evaluation result. | `None` | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python @dataclass class EvaluationReason: """The result of running an evaluator with an optional explanation. Contains a scalar value and an optional "reason" explaining the value. 
Args: value: The scalar result of the evaluation (boolean, integer, float, or string). reason: An optional explanation of the evaluation result. """ value: EvaluationScalar reason: str | None = None ``` ### EvaluationResult Bases: `Generic[EvaluationScalarT]` The details of an individual evaluation result. Contains the name, value, reason, and source evaluator for a single evaluation. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `name` | `str` | The name of the evaluation. | *required* | | `value` | `EvaluationScalarT` | The scalar result of the evaluation. | *required* | | `reason` | `str | None` | An optional explanation of the evaluation result. | *required* | | `source` | `Evaluator` | The evaluator that produced this result. | *required* | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python @dataclass class EvaluationResult(Generic[EvaluationScalarT]): """The details of an individual evaluation result. Contains the name, value, reason, and source evaluator for a single evaluation. Args: name: The name of the evaluation. value: The scalar result of the evaluation. reason: An optional explanation of the evaluation result. source: The evaluator that produced this result. """ name: str value: EvaluationScalarT reason: str | None source: Evaluator def downcast(self, *value_types: type[T]) -> EvaluationResult[T] | None: """Attempt to downcast this result to a more specific type. Args: *value_types: The types to check the value against. Returns: A downcast version of this result if the value is an instance of one of the given types, otherwise None. """ # Check if value matches any of the target types, handling bool as a special case for value_type in value_types: if isinstance(self.value, value_type): # Only match bool with explicit bool type if isinstance(self.value, bool) and value_type is not bool: continue return cast(EvaluationResult[T], self) return None ``` #### downcast ```python downcast( *value_types: type[T], ) -> EvaluationResult[T] | None ``` Attempt to downcast this result to a more specific type. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `*value_types` | `type[T]` | The types to check the value against. | `()` | Returns: | Type | Description | | --- | --- | | `EvaluationResult[T] | None` | A downcast version of this result if the value is an instance of one of the given types, | | `EvaluationResult[T] | None` | otherwise None. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python def downcast(self, *value_types: type[T]) -> EvaluationResult[T] | None: """Attempt to downcast this result to a more specific type. Args: *value_types: The types to check the value against. Returns: A downcast version of this result if the value is an instance of one of the given types, otherwise None. """ # Check if value matches any of the target types, handling bool as a special case for value_type in value_types: if isinstance(self.value, value_type): # Only match bool with explicit bool type if isinstance(self.value, bool) and value_type is not bool: continue return cast(EvaluationResult[T], self) return None ``` ### Evaluator Bases: `Generic[InputsT, OutputT, MetadataT]` Base class for all evaluators. Evaluators can assess the performance of a task in a variety of ways, as a function of the EvaluatorContext. Subclasses must implement the `evaluate` method. Note it can be defined with either `def` or `async def`. 
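For instance, an asynchronous evaluator looks the same apart from `async def`; the evaluation machinery awaits it for you. A sketch (the class name is made up):

```python
from dataclasses import dataclass

from pydantic_evals.evaluators import Evaluator, EvaluatorContext


@dataclass
class AsyncExactMatch(Evaluator):
    async def evaluate(self, ctx: EvaluatorContext) -> bool:
        # an async evaluator may await I/O (e.g. another model call) before returning
        return ctx.output == ctx.expected_output
```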
Example: ```python from dataclasses import dataclass from pydantic_evals.evaluators import Evaluator, EvaluatorContext @dataclass class ExactMatch(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> bool: return ctx.output == ctx.expected_output ``` Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ````python @dataclass(repr=False) class Evaluator(Generic[InputsT, OutputT, MetadataT], metaclass=_StrictABCMeta): """Base class for all evaluators. Evaluators can assess the performance of a task in a variety of ways, as a function of the EvaluatorContext. Subclasses must implement the `evaluate` method. Note it can be defined with either `def` or `async def`. Example: ```python from dataclasses import dataclass from pydantic_evals.evaluators import Evaluator, EvaluatorContext @dataclass class ExactMatch(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> bool: return ctx.output == ctx.expected_output ``` """ __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) @classmethod def get_serialization_name(cls) -> str: """Return the 'name' of this Evaluator to use during serialization. Returns: The name of the Evaluator, which is typically the class name. """ return cls.__name__ @classmethod @deprecated('`name` has been renamed, use `get_serialization_name` instead.') def name(cls) -> str: """`name` has been renamed, use `get_serialization_name` instead.""" return cls.get_serialization_name() def get_default_evaluation_name(self) -> str: """Return the default name to use in reports for the output of this evaluator. By default, if the evaluator has an attribute called `evaluation_name` of type string, that will be used. Otherwise, the serialization name of the evaluator (which is usually the class name) will be used. This can be overridden to get a more descriptive name in evaluation reports, e.g. using instance information. Note that evaluators that return a mapping of results will always use the keys of that mapping as the names of the associated evaluation results. """ evaluation_name = getattr(self, 'evaluation_name', None) if isinstance(evaluation_name, str): # If the evaluator has an attribute `name` of type string, use that return evaluation_name return self.get_serialization_name() @abstractmethod def evaluate( self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT] ) -> EvaluatorOutput | Awaitable[EvaluatorOutput]: # pragma: no cover """Evaluate the task output in the given context. This is the main evaluation method that subclasses must implement. It can be either synchronous or asynchronous, returning either an EvaluatorOutput directly or an Awaitable[EvaluatorOutput]. Args: ctx: The context containing the inputs, outputs, and metadata for evaluation. Returns: The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping of evaluation names to either of those. Can be returned either synchronously or as an awaitable for asynchronous evaluation. """ raise NotImplementedError('You must implement `evaluate`.') def evaluate_sync(self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]) -> EvaluatorOutput: """Run the evaluator synchronously, handling both sync and async implementations. This method ensures synchronous execution by running any async evaluate implementation to completion using run_until_complete. Args: ctx: The context containing the inputs, outputs, and metadata for evaluation. Returns: The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping of evaluation names to either of those. 
""" output = self.evaluate(ctx) if inspect.iscoroutine(output): # pragma: no cover return get_event_loop().run_until_complete(output) else: return cast(EvaluatorOutput, output) async def evaluate_async(self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]) -> EvaluatorOutput: """Run the evaluator asynchronously, handling both sync and async implementations. This method ensures asynchronous execution by properly awaiting any async evaluate implementation. For synchronous implementations, it returns the result directly. Args: ctx: The context containing the inputs, outputs, and metadata for evaluation. Returns: The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping of evaluation names to either of those. """ # Note: If self.evaluate is synchronous, but you need to prevent this from blocking, override this method with: # return await anyio.to_thread.run_sync(self.evaluate, ctx) output = self.evaluate(ctx) if inspect.iscoroutine(output): return await output else: return cast(EvaluatorOutput, output) @model_serializer(mode='plain') def serialize(self, info: SerializationInfo) -> Any: """Serialize this Evaluator to a JSON-serializable form. Returns: A JSON-serializable representation of this evaluator as an EvaluatorSpec. """ raw_arguments = self.build_serialization_arguments() arguments: None | tuple[Any,] | dict[str, Any] if len(raw_arguments) == 0: arguments = None elif len(raw_arguments) == 1: arguments = (next(iter(raw_arguments.values())),) else: arguments = raw_arguments return to_jsonable_python( EvaluatorSpec(name=self.get_serialization_name(), arguments=arguments), context=info.context, serialize_unknown=True, ) def build_serialization_arguments(self) -> dict[str, Any]: """Build the arguments for serialization. Evaluators are serialized for inclusion as the "source" in an `EvaluationResult`. If you want to modify how the evaluator is serialized for that or other purposes, you can override this method. Returns: A dictionary of arguments to be used during serialization. """ raw_arguments: dict[str, Any] = {} for field in fields(self): value = getattr(self, field.name) # always exclude defaults: if field.default is not MISSING: if value == field.default: continue if field.default_factory is not MISSING: if value == field.default_factory(): # pragma: no branch continue raw_arguments[field.name] = value return raw_arguments __repr__ = _utils.dataclasses_no_defaults_repr ```` #### get_serialization_name ```python get_serialization_name() -> str ``` Return the 'name' of this Evaluator to use during serialization. Returns: | Type | Description | | --- | --- | | `str` | The name of the Evaluator, which is typically the class name. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python @classmethod def get_serialization_name(cls) -> str: """Return the 'name' of this Evaluator to use during serialization. Returns: The name of the Evaluator, which is typically the class name. """ return cls.__name__ ``` #### name ```python name() -> str ``` `name` has been renamed, use `get_serialization_name` instead. 
Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python @classmethod @deprecated('`name` has been renamed, use `get_serialization_name` instead.') def name(cls) -> str: """`name` has been renamed, use `get_serialization_name` instead.""" return cls.get_serialization_name() ``` #### get_default_evaluation_name ```python get_default_evaluation_name() -> str ``` Return the default name to use in reports for the output of this evaluator. By default, if the evaluator has an attribute called `evaluation_name` of type string, that will be used. Otherwise, the serialization name of the evaluator (which is usually the class name) will be used. This can be overridden to get a more descriptive name in evaluation reports, e.g. using instance information. Note that evaluators that return a mapping of results will always use the keys of that mapping as the names of the associated evaluation results. Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python def get_default_evaluation_name(self) -> str: """Return the default name to use in reports for the output of this evaluator. By default, if the evaluator has an attribute called `evaluation_name` of type string, that will be used. Otherwise, the serialization name of the evaluator (which is usually the class name) will be used. This can be overridden to get a more descriptive name in evaluation reports, e.g. using instance information. Note that evaluators that return a mapping of results will always use the keys of that mapping as the names of the associated evaluation results. """ evaluation_name = getattr(self, 'evaluation_name', None) if isinstance(evaluation_name, str): # If the evaluator has an attribute `name` of type string, use that return evaluation_name return self.get_serialization_name() ``` #### evaluate ```python evaluate( ctx: EvaluatorContext[InputsT, OutputT, MetadataT], ) -> EvaluatorOutput | Awaitable[EvaluatorOutput] ``` Evaluate the task output in the given context. This is the main evaluation method that subclasses must implement. It can be either synchronous or asynchronous, returning either an EvaluatorOutput directly or an Awaitable[EvaluatorOutput]. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `ctx` | `EvaluatorContext[InputsT, OutputT, MetadataT]` | The context containing the inputs, outputs, and metadata for evaluation. | *required* | Returns: | Type | Description | | --- | --- | | `EvaluatorOutput | Awaitable[EvaluatorOutput]` | The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping | | `EvaluatorOutput | Awaitable[EvaluatorOutput]` | of evaluation names to either of those. Can be returned either synchronously or as an | | `EvaluatorOutput | Awaitable[EvaluatorOutput]` | awaitable for asynchronous evaluation. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python @abstractmethod def evaluate( self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT] ) -> EvaluatorOutput | Awaitable[EvaluatorOutput]: # pragma: no cover """Evaluate the task output in the given context. This is the main evaluation method that subclasses must implement. It can be either synchronous or asynchronous, returning either an EvaluatorOutput directly or an Awaitable[EvaluatorOutput]. Args: ctx: The context containing the inputs, outputs, and metadata for evaluation. Returns: The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping of evaluation names to either of those. 
Can be returned either synchronously or as an awaitable for asynchronous evaluation. """ raise NotImplementedError('You must implement `evaluate`.') ``` #### evaluate_sync ```python evaluate_sync( ctx: EvaluatorContext[InputsT, OutputT, MetadataT], ) -> EvaluatorOutput ``` Run the evaluator synchronously, handling both sync and async implementations. This method ensures synchronous execution by running any async evaluate implementation to completion using run_until_complete. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `ctx` | `EvaluatorContext[InputsT, OutputT, MetadataT]` | The context containing the inputs, outputs, and metadata for evaluation. | *required* | Returns: | Type | Description | | --- | --- | | `EvaluatorOutput` | The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping | | `EvaluatorOutput` | of evaluation names to either of those. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python def evaluate_sync(self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]) -> EvaluatorOutput: """Run the evaluator synchronously, handling both sync and async implementations. This method ensures synchronous execution by running any async evaluate implementation to completion using run_until_complete. Args: ctx: The context containing the inputs, outputs, and metadata for evaluation. Returns: The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping of evaluation names to either of those. """ output = self.evaluate(ctx) if inspect.iscoroutine(output): # pragma: no cover return get_event_loop().run_until_complete(output) else: return cast(EvaluatorOutput, output) ``` #### evaluate_async ```python evaluate_async( ctx: EvaluatorContext[InputsT, OutputT, MetadataT], ) -> EvaluatorOutput ``` Run the evaluator asynchronously, handling both sync and async implementations. This method ensures asynchronous execution by properly awaiting any async evaluate implementation. For synchronous implementations, it returns the result directly. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `ctx` | `EvaluatorContext[InputsT, OutputT, MetadataT]` | The context containing the inputs, outputs, and metadata for evaluation. | *required* | Returns: | Type | Description | | --- | --- | | `EvaluatorOutput` | The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping | | `EvaluatorOutput` | of evaluation names to either of those. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python async def evaluate_async(self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]) -> EvaluatorOutput: """Run the evaluator asynchronously, handling both sync and async implementations. This method ensures asynchronous execution by properly awaiting any async evaluate implementation. For synchronous implementations, it returns the result directly. Args: ctx: The context containing the inputs, outputs, and metadata for evaluation. Returns: The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping of evaluation names to either of those. 
""" # Note: If self.evaluate is synchronous, but you need to prevent this from blocking, override this method with: # return await anyio.to_thread.run_sync(self.evaluate, ctx) output = self.evaluate(ctx) if inspect.iscoroutine(output): return await output else: return cast(EvaluatorOutput, output) ``` #### serialize ```python serialize(info: SerializationInfo) -> Any ``` Serialize this Evaluator to a JSON-serializable form. Returns: | Type | Description | | --- | --- | | `Any` | A JSON-serializable representation of this evaluator as an EvaluatorSpec. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python @model_serializer(mode='plain') def serialize(self, info: SerializationInfo) -> Any: """Serialize this Evaluator to a JSON-serializable form. Returns: A JSON-serializable representation of this evaluator as an EvaluatorSpec. """ raw_arguments = self.build_serialization_arguments() arguments: None | tuple[Any,] | dict[str, Any] if len(raw_arguments) == 0: arguments = None elif len(raw_arguments) == 1: arguments = (next(iter(raw_arguments.values())),) else: arguments = raw_arguments return to_jsonable_python( EvaluatorSpec(name=self.get_serialization_name(), arguments=arguments), context=info.context, serialize_unknown=True, ) ``` #### build_serialization_arguments ```python build_serialization_arguments() -> dict[str, Any] ``` Build the arguments for serialization. Evaluators are serialized for inclusion as the "source" in an `EvaluationResult`. If you want to modify how the evaluator is serialized for that or other purposes, you can override this method. Returns: | Type | Description | | --- | --- | | `dict[str, Any]` | A dictionary of arguments to be used during serialization. | Source code in `pydantic_evals/pydantic_evals/evaluators/evaluator.py` ```python def build_serialization_arguments(self) -> dict[str, Any]: """Build the arguments for serialization. Evaluators are serialized for inclusion as the "source" in an `EvaluationResult`. If you want to modify how the evaluator is serialized for that or other purposes, you can override this method. Returns: A dictionary of arguments to be used during serialization. """ raw_arguments: dict[str, Any] = {} for field in fields(self): value = getattr(self, field.name) # always exclude defaults: if field.default is not MISSING: if value == field.default: continue if field.default_factory is not MISSING: if value == field.default_factory(): # pragma: no branch continue raw_arguments[field.name] = value return raw_arguments ``` ### EvaluatorOutput ```python EvaluatorOutput = Union[ EvaluationScalar, EvaluationReason, Mapping[str, Union[EvaluationScalar, EvaluationReason]], ] ``` Type for the output of an evaluator, which can be a scalar, an EvaluationReason, or a mapping of names to either. ### GradingOutput Bases: `BaseModel` The output of a grading operation. Source code in `pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py` ```python class GradingOutput(BaseModel, populate_by_name=True): """The output of a grading operation.""" reason: str pass_: bool = Field(validation_alias='pass', serialization_alias='pass') score: float ``` ### judge_output ```python judge_output( output: Any, rubric: str, model: Model | KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput ``` Judge the output of a model based on a rubric. If the model is not specified, a default model is used. 
The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. Source code in `pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py` ```python async def judge_output( output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput: """Judge the output of a model based on a rubric. If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ user_prompt = dedent( f""" {_stringify(output)} {rubric} """ ) return ( await _judge_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings) ).output ``` ### judge_input_output ```python judge_input_output( inputs: Any, output: Any, rubric: str, model: Model | KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput ``` Judge the output of a model based on the inputs and a rubric. If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. Source code in `pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py` ```python async def judge_input_output( inputs: Any, output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput: """Judge the output of a model based on the inputs and a rubric. If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ user_prompt = dedent( f""" {_stringify(inputs)} {_stringify(output)} {rubric} """ ) return ( await _judge_input_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings) ).output ``` ### judge_input_output_expected ```python judge_input_output_expected( inputs: Any, output: Any, expected_output: Any, rubric: str, model: Model | KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput ``` Judge the output of a model based on the inputs and a rubric. If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. Source code in `pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py` ```python async def judge_input_output_expected( inputs: Any, output: Any, expected_output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput: """Judge the output of a model based on the inputs and a rubric. If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ user_prompt = dedent( f""" {_stringify(inputs)} {_stringify(expected_output)} {_stringify(output)} {rubric} """ ) return ( await _judge_input_output_expected_agent.run( user_prompt, model=model or _default_model, model_settings=model_settings ) ).output ``` ### judge_output_expected ```python judge_output_expected( output: Any, expected_output: Any, rubric: str, model: Model | KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput ``` Judge the output of a model based on the expected output, output, and a rubric. 
If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. Source code in `pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py` ```python async def judge_output_expected( output: Any, expected_output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None, model_settings: ModelSettings | None = None, ) -> GradingOutput: """Judge the output of a model based on the expected output, output, and a rubric. If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ user_prompt = dedent( f""" {_stringify(expected_output)} {_stringify(output)} {rubric} """ ) return ( await _judge_output_expected_agent.run( user_prompt, model=model or _default_model, model_settings=model_settings ) ).output ``` ### set_default_judge_model ```python set_default_judge_model( model: Model | KnownModelName, ) -> None ``` Set the default model used for judging. This model is used if `None` is passed to the `model` argument of `judge_output` and `judge_input_output`. Source code in `pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py` ```python def set_default_judge_model(model: models.Model | models.KnownModelName) -> None: # pragma: no cover """Set the default model used for judging. This model is used if `None` is passed to the `model` argument of `judge_output` and `judge_input_output`. """ global _default_model _default_model = model ``` # `pydantic_evals.generation` Utilities for generating example datasets for pydantic_evals. This module provides functions for generating sample datasets for testing and examples, using LLMs to create realistic test data with proper structure. ### generate_dataset ```python generate_dataset( *, dataset_type: type[ Dataset[InputsT, OutputT, MetadataT] ], path: Path | str | None = None, custom_evaluator_types: Sequence[ type[Evaluator[InputsT, OutputT, MetadataT]] ] = (), model: Model | KnownModelName = "openai:gpt-4o", n_examples: int = 3, extra_instructions: str | None = None ) -> Dataset[InputsT, OutputT, MetadataT] ``` Use an LLM to generate a dataset of test cases, each consisting of input, expected output, and metadata. This function creates a properly structured dataset with the specified input, output, and metadata types. It uses an LLM to attempt to generate realistic test cases that conform to the types' schemas. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `path` | `Path | str | None` | Optional path to save the generated dataset. If provided, the dataset will be saved to this location. | `None` | | `dataset_type` | `type[Dataset[InputsT, OutputT, MetadataT]]` | The type of dataset to generate, with the desired input, output, and metadata types. | *required* | | `custom_evaluator_types` | `Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]]` | Optional sequence of custom evaluator classes to include in the schema. | `()` | | `model` | `Model | KnownModelName` | The PydanticAI model to use for generation. Defaults to 'gpt-4o'. | `'openai:gpt-4o'` | | `n_examples` | `int` | Number of examples to generate. Defaults to 3. | `3` | | `extra_instructions` | `str | None` | Optional additional instructions to provide to the LLM. 
| `None` | Returns: | Type | Description | | --- | --- | | `Dataset[InputsT, OutputT, MetadataT]` | A properly structured Dataset object with generated test cases. | Raises: | Type | Description | | --- | --- | | `ValidationError` | If the LLM's response cannot be parsed as a valid dataset. | Source code in `pydantic_evals/pydantic_evals/generation.py` ```python async def generate_dataset( *, dataset_type: type[Dataset[InputsT, OutputT, MetadataT]], path: Path | str | None = None, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), model: models.Model | models.KnownModelName = 'openai:gpt-4o', n_examples: int = 3, extra_instructions: str | None = None, ) -> Dataset[InputsT, OutputT, MetadataT]: """Use an LLM to generate a dataset of test cases, each consisting of input, expected output, and metadata. This function creates a properly structured dataset with the specified input, output, and metadata types. It uses an LLM to attempt to generate realistic test cases that conform to the types' schemas. Args: path: Optional path to save the generated dataset. If provided, the dataset will be saved to this location. dataset_type: The type of dataset to generate, with the desired input, output, and metadata types. custom_evaluator_types: Optional sequence of custom evaluator classes to include in the schema. model: The PydanticAI model to use for generation. Defaults to 'gpt-4o'. n_examples: Number of examples to generate. Defaults to 3. extra_instructions: Optional additional instructions to provide to the LLM. Returns: A properly structured Dataset object with generated test cases. Raises: ValidationError: If the LLM's response cannot be parsed as a valid dataset. """ output_schema = dataset_type.model_json_schema_with_evaluators(custom_evaluator_types) # TODO(DavidM): Update this once we add better response_format and/or ResultTool support to PydanticAI agent = Agent( model, system_prompt=( f'Generate an object that is in compliance with this JSON schema:\n{output_schema}\n\n' f'Include {n_examples} example cases.' ' You must not include any characters in your response before the opening { of the JSON object, or after the closing }.' ), output_type=str, retries=1, ) result = await agent.run(extra_instructions or 'Please generate the object.') try: result = dataset_type.from_text(result.output, fmt='json', custom_evaluator_types=custom_evaluator_types) except ValidationError as e: # pragma: no cover print(f'Raw response from model:\n{result.output}') raise e if path is not None: result.to_file(path, custom_evaluator_types=custom_evaluator_types) # pragma: no cover return result ``` # `pydantic_evals.otel` ### SpanNode A node in the span tree; provides references to parents/children for easy traversal and queries. 
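In evaluator code you typically reach these nodes through `EvaluatorContext.span_tree` rather than constructing them yourself. A sketch of a custom evaluator driving a `SpanQuery` (the span name is illustrative):

```python
from dataclasses import dataclass

from pydantic_evals.evaluators import Evaluator, EvaluatorContext


@dataclass
class CalledATool(Evaluator):
    def evaluate(self, ctx: EvaluatorContext) -> bool:
        # True if any recorded span has a name containing 'tool' and took at most 5 seconds
        return ctx.span_tree.any({'name_contains': 'tool', 'max_duration': 5.0})
```

Note that accessing `ctx.span_tree` raises `SpanTreeRecordingError` if spans were not captured for the run.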
Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python @dataclass(repr=False) class SpanNode: """A node in the span tree; provides references to parents/children for easy traversal and queries.""" name: str trace_id: int span_id: int parent_span_id: int | None start_timestamp: datetime end_timestamp: datetime attributes: dict[str, AttributeValue] @property def duration(self) -> timedelta: """Return the span's duration as a timedelta, or None if start/end not set.""" return self.end_timestamp - self.start_timestamp @property def children(self) -> list[SpanNode]: return list(self.children_by_id.values()) @property def descendants(self) -> list[SpanNode]: """Return all descendants of this node in DFS order.""" return self.find_descendants(lambda _: True) @property def ancestors(self) -> list[SpanNode]: """Return all ancestors of this node.""" return self.find_ancestors(lambda _: True) @property def node_key(self) -> str: return f'{self.trace_id:032x}:{self.span_id:016x}' @property def parent_node_key(self) -> str | None: return None if self.parent_span_id is None else f'{self.trace_id:032x}:{self.parent_span_id:016x}' # ------------------------------------------------------------------------- # Construction # ------------------------------------------------------------------------- def __post_init__(self): self.parent: SpanNode | None = None self.children_by_id: dict[str, SpanNode] = {} @staticmethod def from_readable_span(span: ReadableSpan) -> SpanNode: assert span.context is not None, 'Span has no context' assert span.start_time is not None, 'Span has no start time' assert span.end_time is not None, 'Span has no end time' return SpanNode( name=span.name, trace_id=span.context.trace_id, span_id=span.context.span_id, parent_span_id=span.parent.span_id if span.parent else None, start_timestamp=datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc), end_timestamp=datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc), attributes=dict(span.attributes or {}), ) def add_child(self, child: SpanNode) -> None: """Attach a child node to this node's list of children.""" assert child.trace_id == self.trace_id, f"traces don't match: {child.trace_id:032x} != {self.trace_id:032x}" assert child.parent_span_id == self.span_id, ( f'parent span mismatch: {child.parent_span_id:016x} != {self.span_id:016x}' ) self.children_by_id[child.node_key] = child child.parent = self # ------------------------------------------------------------------------- # Child queries # ------------------------------------------------------------------------- def find_children(self, predicate: SpanQuery | SpanPredicate) -> list[SpanNode]: """Return all immediate children that satisfy the given predicate.""" return list(self._filter_children(predicate)) def first_child(self, predicate: SpanQuery | SpanPredicate) -> SpanNode | None: """Return the first immediate child that satisfies the given predicate, or None if none match.""" return next(self._filter_children(predicate), None) def any_child(self, predicate: SpanQuery | SpanPredicate) -> bool: """Returns True if there is at least one child that satisfies the predicate.""" return self.first_child(predicate) is not None def _filter_children(self, predicate: SpanQuery | SpanPredicate) -> Iterator[SpanNode]: return (child for child in self.children if child.matches(predicate)) # ------------------------------------------------------------------------- # Descendant queries (DFS) # ------------------------------------------------------------------------- 
def find_descendants( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> list[SpanNode]: """Return all descendant nodes that satisfy the given predicate in DFS order.""" return list(self._filter_descendants(predicate, stop_recursing_when)) def first_descendant( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> SpanNode | None: """DFS: Return the first descendant (in DFS order) that satisfies the given predicate, or `None` if none match.""" return next(self._filter_descendants(predicate, stop_recursing_when), None) def any_descendant( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> bool: """Returns `True` if there is at least one descendant that satisfies the predicate.""" return self.first_descendant(predicate, stop_recursing_when) is not None def _filter_descendants( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None ) -> Iterator[SpanNode]: stack = list(self.children) while stack: node = stack.pop() if node.matches(predicate): yield node if stop_recursing_when is not None and node.matches(stop_recursing_when): continue stack.extend(node.children) # ------------------------------------------------------------------------- # Ancestor queries (DFS "up" the chain) # ------------------------------------------------------------------------- def find_ancestors( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> list[SpanNode]: """Return all ancestors that satisfy the given predicate.""" return list(self._filter_ancestors(predicate, stop_recursing_when)) def first_ancestor( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> SpanNode | None: """Return the closest ancestor that satisfies the given predicate, or `None` if none match.""" return next(self._filter_ancestors(predicate, stop_recursing_when), None) def any_ancestor( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> bool: """Returns True if any ancestor satisfies the predicate.""" return self.first_ancestor(predicate, stop_recursing_when) is not None def _filter_ancestors( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None ) -> Iterator[SpanNode]: node = self.parent while node: if node.matches(predicate): yield node if stop_recursing_when is not None and node.matches(stop_recursing_when): break node = node.parent # ------------------------------------------------------------------------- # Query matching # ------------------------------------------------------------------------- def matches(self, query: SpanQuery | SpanPredicate) -> bool: """Check if the span node matches the query conditions or predicate.""" if callable(query): return query(self) return self._matches_query(query) def _matches_query(self, query: SpanQuery) -> bool: # noqa C901 """Check if the span matches the query conditions.""" # Logical combinations if or_ := query.get('or_'): if len(query) > 1: raise ValueError("Cannot combine 'or_' conditions with other conditions at the same level") return any(self._matches_query(q) for q in or_) if not_ := query.get('not_'): if self._matches_query(not_): return False if and_ := query.get('and_'): results = [self._matches_query(q) for q in and_] if not all(results): return False # At this 
point, all existing ANDs and no existing ORs have passed, so it comes down to this condition # Name conditions if (name_equals := query.get('name_equals')) and self.name != name_equals: return False if (name_contains := query.get('name_contains')) and name_contains not in self.name: return False if (name_matches_regex := query.get('name_matches_regex')) and not re.match(name_matches_regex, self.name): return False # Attribute conditions if (has_attributes := query.get('has_attributes')) and not all( self.attributes.get(key) == value for key, value in has_attributes.items() ): return False if (has_attributes_keys := query.get('has_attribute_keys')) and not all( key in self.attributes for key in has_attributes_keys ): return False # Timing conditions if (min_duration := query.get('min_duration')) is not None: if not isinstance(min_duration, timedelta): min_duration = timedelta(seconds=min_duration) if self.duration < min_duration: return False if (max_duration := query.get('max_duration')) is not None: if not isinstance(max_duration, timedelta): max_duration = timedelta(seconds=max_duration) if self.duration > max_duration: return False # Children conditions if (min_child_count := query.get('min_child_count')) and len(self.children) < min_child_count: return False if (max_child_count := query.get('max_child_count')) and len(self.children) > max_child_count: return False if (some_child_has := query.get('some_child_has')) and not any( child._matches_query(some_child_has) for child in self.children ): return False if (all_children_have := query.get('all_children_have')) and not all( child._matches_query(all_children_have) for child in self.children ): return False if (no_child_has := query.get('no_child_has')) and any( child._matches_query(no_child_has) for child in self.children ): return False # Descendant conditions # The following local functions with cache decorators are used to avoid repeatedly evaluating these properties @cache def descendants(): return self.descendants @cache def pruned_descendants(): stop_recursing_when = query.get('stop_recursing_when') return ( self._filter_descendants(lambda _: True, stop_recursing_when) if stop_recursing_when else descendants() ) if (min_descendant_count := query.get('min_descendant_count')) and len(descendants()) < min_descendant_count: return False if (max_descendant_count := query.get('max_descendant_count')) and len(descendants()) > max_descendant_count: return False if (some_descendant_has := query.get('some_descendant_has')) and not any( descendant._matches_query(some_descendant_has) for descendant in pruned_descendants() ): return False if (all_descendants_have := query.get('all_descendants_have')) and not all( descendant._matches_query(all_descendants_have) for descendant in pruned_descendants() ): return False if (no_descendant_has := query.get('no_descendant_has')) and any( descendant._matches_query(no_descendant_has) for descendant in pruned_descendants() ): return False # Ancestor conditions # The following local functions with cache decorators are used to avoid repeatedly evaluating these properties @cache def ancestors(): return self.ancestors @cache def pruned_ancestors(): stop_recursing_when = query.get('stop_recursing_when') return self._filter_ancestors(lambda _: True, stop_recursing_when) if stop_recursing_when else ancestors() if (min_depth := query.get('min_depth')) and len(ancestors()) < min_depth: return False if (max_depth := query.get('max_depth')) and len(ancestors()) > max_depth: return False if (some_ancestor_has := 
query.get('some_ancestor_has')) and not any( ancestor._matches_query(some_ancestor_has) for ancestor in pruned_ancestors() ): return False if (all_ancestors_have := query.get('all_ancestors_have')) and not all( ancestor._matches_query(all_ancestors_have) for ancestor in pruned_ancestors() ): return False if (no_ancestor_has := query.get('no_ancestor_has')) and any( ancestor._matches_query(no_ancestor_has) for ancestor in pruned_ancestors() ): return False return True # ------------------------------------------------------------------------- # String representation # ------------------------------------------------------------------------- def repr_xml( self, include_children: bool = True, include_trace_id: bool = False, include_span_id: bool = False, include_start_timestamp: bool = False, include_duration: bool = False, ) -> str: """Return an XML-like string representation of the node. Optionally includes children, trace_id, span_id, start_timestamp, and duration. """ first_line_parts = [f'') for child in self.children: extra_lines.append( indent( child.repr_xml( include_children=include_children, include_trace_id=include_trace_id, include_span_id=include_span_id, include_start_timestamp=include_start_timestamp, include_duration=include_duration, ), ' ', ) ) extra_lines.append('') else: if self.children: first_line_parts.append('children=...') first_line_parts.append('/>') return '\n'.join([' '.join(first_line_parts), *extra_lines]) def __str__(self) -> str: if self.children: return f"..." else: return f"" def __repr__(self) -> str: return self.repr_xml() ``` #### duration ```python duration: timedelta ``` Return the span's duration as a timedelta, or None if start/end not set. #### descendants ```python descendants: list[SpanNode] ``` Return all descendants of this node in DFS order. #### ancestors ```python ancestors: list[SpanNode] ``` Return all ancestors of this node. #### add_child ```python add_child(child: SpanNode) -> None ``` Attach a child node to this node's list of children. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def add_child(self, child: SpanNode) -> None: """Attach a child node to this node's list of children.""" assert child.trace_id == self.trace_id, f"traces don't match: {child.trace_id:032x} != {self.trace_id:032x}" assert child.parent_span_id == self.span_id, ( f'parent span mismatch: {child.parent_span_id:016x} != {self.span_id:016x}' ) self.children_by_id[child.node_key] = child child.parent = self ``` #### find_children ```python find_children( predicate: SpanQuery | SpanPredicate, ) -> list[SpanNode] ``` Return all immediate children that satisfy the given predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def find_children(self, predicate: SpanQuery | SpanPredicate) -> list[SpanNode]: """Return all immediate children that satisfy the given predicate.""" return list(self._filter_children(predicate)) ``` #### first_child ```python first_child( predicate: SpanQuery | SpanPredicate, ) -> SpanNode | None ``` Return the first immediate child that satisfies the given predicate, or None if none match. 
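For illustration, here is a minimal sketch of how these child-level helpers can be called, assuming `node` is a `SpanNode` taken from an already-built `SpanTree` (the names and attribute values below are illustrative, not part of this API):

```python
from datetime import timedelta

# Sketch only: `node` is assumed to be a SpanNode from an existing SpanTree.

# Immediate children whose span name contains 'tool', matched via a SpanQuery dict:
tool_children = node.find_children({'name_contains': 'tool'})

# The same helpers also accept a plain callable predicate (a SpanPredicate);
# this assumes the node has both start and end timestamps so `duration` is set:
slow_child = node.first_child(lambda n: n.duration > timedelta(seconds=1))

# any_child is a convenience wrapper that just checks whether first_child found anything:
has_failed_child = node.any_child({'has_attributes': {'error': True}})
```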
Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def first_child(self, predicate: SpanQuery | SpanPredicate) -> SpanNode | None: """Return the first immediate child that satisfies the given predicate, or None if none match.""" return next(self._filter_children(predicate), None) ``` #### any_child ```python any_child(predicate: SpanQuery | SpanPredicate) -> bool ``` Returns True if there is at least one child that satisfies the predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def any_child(self, predicate: SpanQuery | SpanPredicate) -> bool: """Returns True if there is at least one child that satisfies the predicate.""" return self.first_child(predicate) is not None ``` #### find_descendants ```python find_descendants( predicate: SpanQuery | SpanPredicate, stop_recursing_when: ( SpanQuery | SpanPredicate | None ) = None, ) -> list[SpanNode] ``` Return all descendant nodes that satisfy the given predicate in DFS order. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def find_descendants( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> list[SpanNode]: """Return all descendant nodes that satisfy the given predicate in DFS order.""" return list(self._filter_descendants(predicate, stop_recursing_when)) ``` #### first_descendant ```python first_descendant( predicate: SpanQuery | SpanPredicate, stop_recursing_when: ( SpanQuery | SpanPredicate | None ) = None, ) -> SpanNode | None ``` DFS: Return the first descendant (in DFS order) that satisfies the given predicate, or `None` if none match. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def first_descendant( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> SpanNode | None: """DFS: Return the first descendant (in DFS order) that satisfies the given predicate, or `None` if none match.""" return next(self._filter_descendants(predicate, stop_recursing_when), None) ``` #### any_descendant ```python any_descendant( predicate: SpanQuery | SpanPredicate, stop_recursing_when: ( SpanQuery | SpanPredicate | None ) = None, ) -> bool ``` Returns `True` if there is at least one descendant that satisfies the predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def any_descendant( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> bool: """Returns `True` if there is at least one descendant that satisfies the predicate.""" return self.first_descendant(predicate, stop_recursing_when) is not None ``` #### find_ancestors ```python find_ancestors( predicate: SpanQuery | SpanPredicate, stop_recursing_when: ( SpanQuery | SpanPredicate | None ) = None, ) -> list[SpanNode] ``` Return all ancestors that satisfy the given predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def find_ancestors( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> list[SpanNode]: """Return all ancestors that satisfy the given predicate.""" return list(self._filter_ancestors(predicate, stop_recursing_when)) ``` #### first_ancestor ```python first_ancestor( predicate: SpanQuery | SpanPredicate, stop_recursing_when: ( SpanQuery | SpanPredicate | None ) = None, ) -> SpanNode | None ``` Return the closest ancestor that satisfies the given predicate, or `None` if none match. 
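The recursive variants work the same way; a sketch, again assuming `node` is an existing `SpanNode`, where `stop_recursing_when` prunes the traversal at nodes matching that condition:

```python
# Sketch only: `node` is assumed to be an existing SpanNode; names are illustrative.

# All descendants named 'db query', without recursing below any 'cache' span:
db_spans = node.find_descendants(
    {'name_equals': 'db query'},
    stop_recursing_when={'name_contains': 'cache'},
)

# Conditions can be combined with and_/or_/not_ and timing conditions
# (min_duration accepts either a timedelta or a number of seconds):
slow_http = node.any_descendant(
    {'and_': [{'name_matches_regex': r'^HTTP '}, {'min_duration': 0.5}]}
)

# Walk up the chain instead of down:
agent_run = node.first_ancestor({'name_contains': 'agent run'})
```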
Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def first_ancestor( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> SpanNode | None: """Return the closest ancestor that satisfies the given predicate, or `None` if none match.""" return next(self._filter_ancestors(predicate, stop_recursing_when), None) ``` #### any_ancestor ```python any_ancestor( predicate: SpanQuery | SpanPredicate, stop_recursing_when: ( SpanQuery | SpanPredicate | None ) = None, ) -> bool ``` Returns True if any ancestor satisfies the predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def any_ancestor( self, predicate: SpanQuery | SpanPredicate, stop_recursing_when: SpanQuery | SpanPredicate | None = None ) -> bool: """Returns True if any ancestor satisfies the predicate.""" return self.first_ancestor(predicate, stop_recursing_when) is not None ``` #### matches ```python matches(query: SpanQuery | SpanPredicate) -> bool ``` Check if the span node matches the query conditions or predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def matches(self, query: SpanQuery | SpanPredicate) -> bool: """Check if the span node matches the query conditions or predicate.""" if callable(query): return query(self) return self._matches_query(query) ``` #### repr_xml ```python repr_xml( include_children: bool = True, include_trace_id: bool = False, include_span_id: bool = False, include_start_timestamp: bool = False, include_duration: bool = False, ) -> str ``` Return an XML-like string representation of the node. Optionally includes children, trace_id, span_id, start_timestamp, and duration. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def repr_xml( self, include_children: bool = True, include_trace_id: bool = False, include_span_id: bool = False, include_start_timestamp: bool = False, include_duration: bool = False, ) -> str: """Return an XML-like string representation of the node. Optionally includes children, trace_id, span_id, start_timestamp, and duration. """ first_line_parts = [f'') for child in self.children: extra_lines.append( indent( child.repr_xml( include_children=include_children, include_trace_id=include_trace_id, include_span_id=include_span_id, include_start_timestamp=include_start_timestamp, include_duration=include_duration, ), ' ', ) ) extra_lines.append('') else: if self.children: first_line_parts.append('children=...') first_line_parts.append('/>') return '\n'.join([' '.join(first_line_parts), *extra_lines]) ``` ### SpanQuery Bases: `TypedDict` A serializable query for filtering SpanNodes based on various conditions. All fields are optional and combined with AND logic by default. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python class SpanQuery(TypedDict, total=False): """A serializable query for filtering SpanNodes based on various conditions. All fields are optional and combined with AND logic by default. """ # These fields are ordered to match the implementation of SpanNode.matches_query for easy review. 
# * Individual span conditions come first because these are generally the cheapest to evaluate # * Logical combinations come next because they may just be combinations of individual span conditions # * Related-span conditions come last because they may require the most work to evaluate # Individual span conditions ## Name conditions name_equals: str name_contains: str name_matches_regex: str # regex pattern ## Attribute conditions has_attributes: dict[str, Any] has_attribute_keys: list[str] ## Timing conditions min_duration: timedelta | float max_duration: timedelta | float # Logical combinations of conditions not_: SpanQuery and_: list[SpanQuery] or_: list[SpanQuery] # Child conditions min_child_count: int max_child_count: int some_child_has: SpanQuery all_children_have: SpanQuery no_child_has: SpanQuery # Recursive conditions stop_recursing_when: SpanQuery """If present, stop recursing through ancestors or descendants at nodes that match this condition.""" ## Descendant conditions min_descendant_count: int max_descendant_count: int some_descendant_has: SpanQuery all_descendants_have: SpanQuery no_descendant_has: SpanQuery ## Ancestor conditions min_depth: int # depth is equivalent to ancestor count; roots have depth 0 max_depth: int some_ancestor_has: SpanQuery all_ancestors_have: SpanQuery no_ancestor_has: SpanQuery ``` #### stop_recursing_when ```python stop_recursing_when: SpanQuery ``` If present, stop recursing through ancestors or descendants at nodes that match this condition. ### SpanTree A container that builds a hierarchy of SpanNode objects from a list of finished spans. You can then search or iterate the tree to make your assertions (using DFS for traversal). Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python @dataclass(repr=False) class SpanTree: """A container that builds a hierarchy of SpanNode objects from a list of finished spans. You can then search or iterate the tree to make your assertions (using DFS for traversal). """ roots: list[SpanNode] = field(default_factory=list) nodes_by_id: dict[str, SpanNode] = field(default_factory=dict) # ------------------------------------------------------------------------- # Construction # ------------------------------------------------------------------------- def __post_init__(self): self._rebuild_tree() def add_spans(self, spans: list[SpanNode]) -> None: """Add a list of spans to the tree, rebuilding the tree structure.""" for span in spans: self.nodes_by_id[span.node_key] = span self._rebuild_tree() def add_readable_spans(self, readable_spans: list[ReadableSpan]): self.add_spans([SpanNode.from_readable_span(span) for span in readable_spans]) def _rebuild_tree(self): # Ensure spans are ordered by start_timestamp so that roots and children end up in the right order nodes = list(self.nodes_by_id.values()) nodes.sort(key=lambda node: node.start_timestamp or datetime.min) self.nodes_by_id = {node.node_key: node for node in nodes} # Build the parent/child relationships for node in self.nodes_by_id.values(): parent_node_key = node.parent_node_key if parent_node_key is not None: parent_node = self.nodes_by_id.get(parent_node_key) if parent_node is not None: parent_node.add_child(node) # Determine the roots # A node is a "root" if its parent is None or if its parent's span_id is not in the current set of spans. 
self.roots = [] for node in self.nodes_by_id.values(): parent_node_key = node.parent_node_key if parent_node_key is None or parent_node_key not in self.nodes_by_id: self.roots.append(node) # ------------------------------------------------------------------------- # Node filtering and iteration # ------------------------------------------------------------------------- def find(self, predicate: SpanQuery | SpanPredicate) -> list[SpanNode]: """Find all nodes in the entire tree that match the predicate, scanning from each root in DFS order.""" return list(self._filter(predicate)) def first(self, predicate: SpanQuery | SpanPredicate) -> SpanNode | None: """Find the first node that matches a predicate, scanning from each root in DFS order. Returns `None` if not found.""" return next(self._filter(predicate), None) def any(self, predicate: SpanQuery | SpanPredicate) -> bool: """Returns True if any node in the tree matches the predicate.""" return self.first(predicate) is not None def _filter(self, predicate: SpanQuery | SpanPredicate) -> Iterator[SpanNode]: for node in self: if node.matches(predicate): yield node def __iter__(self) -> Iterator[SpanNode]: """Return an iterator over all nodes in the tree.""" return iter(self.nodes_by_id.values()) # ------------------------------------------------------------------------- # String representation # ------------------------------------------------------------------------- def repr_xml( self, include_children: bool = True, include_trace_id: bool = False, include_span_id: bool = False, include_start_timestamp: bool = False, include_duration: bool = False, ) -> str: """Return an XML-like string representation of the tree, optionally including children, trace_id, span_id, duration, and timestamps.""" if not self.roots: return '' repr_parts = [ '', *[ indent( root.repr_xml( include_children=include_children, include_trace_id=include_trace_id, include_span_id=include_span_id, include_start_timestamp=include_start_timestamp, include_duration=include_duration, ), ' ', ) for root in self.roots ], '', ] return '\n'.join(repr_parts) def __str__(self): return f'' def __repr__(self): return self.repr_xml() ``` #### add_spans ```python add_spans(spans: list[SpanNode]) -> None ``` Add a list of spans to the tree, rebuilding the tree structure. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def add_spans(self, spans: list[SpanNode]) -> None: """Add a list of spans to the tree, rebuilding the tree structure.""" for span in spans: self.nodes_by_id[span.node_key] = span self._rebuild_tree() ``` #### find ```python find( predicate: SpanQuery | SpanPredicate, ) -> list[SpanNode] ``` Find all nodes in the entire tree that match the predicate, scanning from each root in DFS order. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def find(self, predicate: SpanQuery | SpanPredicate) -> list[SpanNode]: """Find all nodes in the entire tree that match the predicate, scanning from each root in DFS order.""" return list(self._filter(predicate)) ``` #### first ```python first( predicate: SpanQuery | SpanPredicate, ) -> SpanNode | None ``` Find the first node that matches a predicate, scanning from each root in DFS order. Returns `None` if not found. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def first(self, predicate: SpanQuery | SpanPredicate) -> SpanNode | None: """Find the first node that matches a predicate, scanning from each root in DFS order. 
Returns `None` if not found.""" return next(self._filter(predicate), None) ``` #### any ```python any(predicate: SpanQuery | SpanPredicate) -> bool ``` Returns True if any node in the tree matches the predicate. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def any(self, predicate: SpanQuery | SpanPredicate) -> bool: """Returns True if any node in the tree matches the predicate.""" return self.first(predicate) is not None ``` #### __iter__ ```python __iter__() -> Iterator[SpanNode] ``` Return an iterator over all nodes in the tree. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def __iter__(self) -> Iterator[SpanNode]: """Return an iterator over all nodes in the tree.""" return iter(self.nodes_by_id.values()) ``` #### repr_xml ```python repr_xml( include_children: bool = True, include_trace_id: bool = False, include_span_id: bool = False, include_start_timestamp: bool = False, include_duration: bool = False, ) -> str ``` Return an XML-like string representation of the tree, optionally including children, trace_id, span_id, duration, and timestamps. Source code in `pydantic_evals/pydantic_evals/otel/span_tree.py` ```python def repr_xml( self, include_children: bool = True, include_trace_id: bool = False, include_span_id: bool = False, include_start_timestamp: bool = False, include_duration: bool = False, ) -> str: """Return an XML-like string representation of the tree, optionally including children, trace_id, span_id, duration, and timestamps.""" if not self.roots: return '' repr_parts = [ '', *[ indent( root.repr_xml( include_children=include_children, include_trace_id=include_trace_id, include_span_id=include_span_id, include_start_timestamp=include_start_timestamp, include_duration=include_duration, ), ' ', ) for root in self.roots ], '', ] return '\n'.join(repr_parts) ``` # `pydantic_evals.reporting` ### ReportCase Bases: `BaseModel` A single case in an evaluation report. Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python class ReportCase(BaseModel): """A single case in an evaluation report.""" name: str """The name of the [case][pydantic_evals.Case].""" inputs: Any """The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs].""" metadata: Any """Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata].""" expected_output: Any """The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output].""" output: Any """The output of the task execution.""" metrics: dict[str, float | int] attributes: dict[str, Any] scores: dict[str, EvaluationResult[int | float]] = field(init=False) labels: dict[str, EvaluationResult[str]] = field(init=False) assertions: dict[str, EvaluationResult[bool]] = field(init=False) task_duration: float total_duration: float # includes evaluator execution time # TODO(DavidM): Drop these once we can reference child spans in details panel: trace_id: str span_id: str ``` #### name ```python name: str ``` The name of the case. #### inputs ```python inputs: Any ``` The inputs to the task, from Case.inputs. #### metadata ```python metadata: Any ``` Any metadata associated with the case, from Case.metadata. #### expected_output ```python expected_output: Any ``` The expected output of the task, from Case.expected_output. #### output ```python output: Any ``` The output of the task execution. ### ReportCaseAggregate Bases: `BaseModel` A synthetic case that summarizes a set of cases. 
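Before the aggregation logic, here is a minimal sketch of reading the per-case results it summarizes, assuming `report` is an `EvaluationReport` produced elsewhere:

```python
# Sketch only: `report` is assumed to be an existing EvaluationReport.
for case in report.cases:
    # scores/labels/assertions are keyed by evaluator name; `.value` holds the result
    score_values = {name: result.value for name, result in case.scores.items()}
    passed = all(assertion.value for assertion in case.assertions.values())
    print(case.name, score_values, passed, f'{case.task_duration:.3f}s')
```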
Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python class ReportCaseAggregate(BaseModel): """A synthetic case that summarizes a set of cases.""" name: str scores: dict[str, float | int] labels: dict[str, dict[str, float]] metrics: dict[str, float | int] assertions: float | None task_duration: float total_duration: float @staticmethod def average(cases: list[ReportCase]) -> ReportCaseAggregate: """Produce a synthetic "summary" case by averaging quantitative attributes.""" num_cases = len(cases) if num_cases == 0: return ReportCaseAggregate( name='Averages', scores={}, labels={}, metrics={}, assertions=None, task_duration=0.0, total_duration=0.0, ) def _scores_averages(scores_by_name: list[dict[str, int | float | bool]]) -> dict[str, float]: counts_by_name: dict[str, int] = defaultdict(int) sums_by_name: dict[str, float] = defaultdict(float) for sbn in scores_by_name: for name, score in sbn.items(): counts_by_name[name] += 1 sums_by_name[name] += score return {name: sums_by_name[name] / counts_by_name[name] for name in sums_by_name} def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str, float]]: counts_by_name: dict[str, int] = defaultdict(int) sums_by_name: dict[str, dict[str, float]] = defaultdict(lambda: defaultdict(float)) for lbn in labels_by_name: for name, label in lbn.items(): counts_by_name[name] += 1 sums_by_name[name][label] += 1 return { name: {value: count / counts_by_name[name] for value, count in sums_by_name[name].items()} for name in sums_by_name } average_task_duration = sum(case.task_duration for case in cases) / num_cases average_total_duration = sum(case.total_duration for case in cases) / num_cases # average_assertions: dict[str, float] = _scores_averages([{k: v.value for k, v in case.scores.items()} for case in cases]) average_scores: dict[str, float] = _scores_averages( [{k: v.value for k, v in case.scores.items()} for case in cases] ) average_labels: dict[str, dict[str, float]] = _labels_averages( [{k: v.value for k, v in case.labels.items()} for case in cases] ) average_metrics: dict[str, float] = _scores_averages([case.metrics for case in cases]) average_assertions: float | None = None n_assertions = sum(len(case.assertions) for case in cases) if n_assertions > 0: n_passing = sum(1 for case in cases for assertion in case.assertions.values() if assertion.value) average_assertions = n_passing / n_assertions return ReportCaseAggregate( name='Averages', scores=average_scores, labels=average_labels, metrics=average_metrics, assertions=average_assertions, task_duration=average_task_duration, total_duration=average_total_duration, ) ``` #### average ```python average(cases: list[ReportCase]) -> ReportCaseAggregate ``` Produce a synthetic "summary" case by averaging quantitative attributes. 
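For example (a sketch, still assuming an existing `report`), this is the aggregate that backs the `Averages` row of a printed report:

```python
from pydantic_evals.reporting import ReportCaseAggregate

# Sketch only: `report` is assumed to be an existing EvaluationReport.
summary = ReportCaseAggregate.average(report.cases)
print(summary.name)        #> Averages
print(summary.assertions)  # fraction of passing assertions across all cases, or None

# EvaluationReport.averages() is a thin wrapper around the same call:
summary = report.averages()
```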
Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python @staticmethod def average(cases: list[ReportCase]) -> ReportCaseAggregate: """Produce a synthetic "summary" case by averaging quantitative attributes.""" num_cases = len(cases) if num_cases == 0: return ReportCaseAggregate( name='Averages', scores={}, labels={}, metrics={}, assertions=None, task_duration=0.0, total_duration=0.0, ) def _scores_averages(scores_by_name: list[dict[str, int | float | bool]]) -> dict[str, float]: counts_by_name: dict[str, int] = defaultdict(int) sums_by_name: dict[str, float] = defaultdict(float) for sbn in scores_by_name: for name, score in sbn.items(): counts_by_name[name] += 1 sums_by_name[name] += score return {name: sums_by_name[name] / counts_by_name[name] for name in sums_by_name} def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str, float]]: counts_by_name: dict[str, int] = defaultdict(int) sums_by_name: dict[str, dict[str, float]] = defaultdict(lambda: defaultdict(float)) for lbn in labels_by_name: for name, label in lbn.items(): counts_by_name[name] += 1 sums_by_name[name][label] += 1 return { name: {value: count / counts_by_name[name] for value, count in sums_by_name[name].items()} for name in sums_by_name } average_task_duration = sum(case.task_duration for case in cases) / num_cases average_total_duration = sum(case.total_duration for case in cases) / num_cases # average_assertions: dict[str, float] = _scores_averages([{k: v.value for k, v in case.scores.items()} for case in cases]) average_scores: dict[str, float] = _scores_averages( [{k: v.value for k, v in case.scores.items()} for case in cases] ) average_labels: dict[str, dict[str, float]] = _labels_averages( [{k: v.value for k, v in case.labels.items()} for case in cases] ) average_metrics: dict[str, float] = _scores_averages([case.metrics for case in cases]) average_assertions: float | None = None n_assertions = sum(len(case.assertions) for case in cases) if n_assertions > 0: n_passing = sum(1 for case in cases for assertion in case.assertions.values() if assertion.value) average_assertions = n_passing / n_assertions return ReportCaseAggregate( name='Averages', scores=average_scores, labels=average_labels, metrics=average_metrics, assertions=average_assertions, task_duration=average_task_duration, total_duration=average_total_duration, ) ``` ### EvaluationReport Bases: `BaseModel` A report of the results of evaluating a model on a set of cases. 
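A hedged usage sketch; the `dataset.evaluate_sync(...)` call, `my_task`, and `baseline_report` are assumptions about code defined outside this page:

```python
# Sketch only: `dataset` and `my_task` are assumed to be defined elsewhere
# (e.g. a pydantic_evals Dataset and the task function under evaluation).
report = dataset.evaluate_sync(my_task)

# Render a summary table to the console, including model inputs and outputs:
report.print(include_input=True, include_output=True)

# Or compare against a previously captured report (assumed to exist):
report.print(baseline=baseline_report, include_removed_cases=True)
```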
Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python class EvaluationReport(BaseModel): """A report of the results of evaluating a model on a set of cases.""" name: str """The name of the report.""" cases: list[ReportCase] """The cases in the report.""" def averages(self) -> ReportCaseAggregate: return ReportCaseAggregate.average(self.cases) def print( self, width: int | None = None, baseline: EvaluationReport | None = None, include_input: bool = False, include_metadata: bool = False, include_expected_output: bool = False, include_output: bool = False, include_durations: bool = True, include_total_duration: bool = False, include_removed_cases: bool = False, include_averages: bool = True, input_config: RenderValueConfig | None = None, metadata_config: RenderValueConfig | None = None, output_config: RenderValueConfig | None = None, score_configs: dict[str, RenderNumberConfig] | None = None, label_configs: dict[str, RenderValueConfig] | None = None, metric_configs: dict[str, RenderNumberConfig] | None = None, duration_config: RenderNumberConfig | None = None, ): # pragma: no cover """Print this report to the console, optionally comparing it to a baseline report. If you want more control over the output, use `console_table` instead and pass it to `rich.Console.print`. """ table = self.console_table( baseline=baseline, include_input=include_input, include_metadata=include_metadata, include_expected_output=include_expected_output, include_output=include_output, include_durations=include_durations, include_total_duration=include_total_duration, include_removed_cases=include_removed_cases, include_averages=include_averages, input_config=input_config, metadata_config=metadata_config, output_config=output_config, score_configs=score_configs, label_configs=label_configs, metric_configs=metric_configs, duration_config=duration_config, ) Console(width=width).print(table) def console_table( self, baseline: EvaluationReport | None = None, include_input: bool = False, include_metadata: bool = False, include_expected_output: bool = False, include_output: bool = False, include_durations: bool = True, include_total_duration: bool = False, include_removed_cases: bool = False, include_averages: bool = True, input_config: RenderValueConfig | None = None, metadata_config: RenderValueConfig | None = None, output_config: RenderValueConfig | None = None, score_configs: dict[str, RenderNumberConfig] | None = None, label_configs: dict[str, RenderValueConfig] | None = None, metric_configs: dict[str, RenderNumberConfig] | None = None, duration_config: RenderNumberConfig | None = None, ) -> Table: """Return a table containing the data from this report, or the diff between this report and a baseline report. Optionally include input and output details. 
""" renderer = EvaluationRenderer( include_input=include_input, include_metadata=include_metadata, include_expected_output=include_expected_output, include_output=include_output, include_durations=include_durations, include_total_duration=include_total_duration, include_removed_cases=include_removed_cases, include_averages=include_averages, input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})}, metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})}, output_config=output_config or _DEFAULT_VALUE_CONFIG, score_configs=score_configs or {}, label_configs=label_configs or {}, metric_configs=metric_configs or {}, duration_config=duration_config or _DEFAULT_DURATION_CONFIG, ) if baseline is None: return renderer.build_table(self) else: # pragma: no cover return renderer.build_diff_table(self, baseline) def __str__(self) -> str: # pragma: lax no cover """Return a string representation of the report.""" table = self.console_table() io_file = StringIO() Console(file=io_file).print(table) return io_file.getvalue() ``` #### name ```python name: str ``` The name of the report. #### cases ```python cases: list[ReportCase] ``` The cases in the report. #### print ```python print( width: int | None = None, baseline: EvaluationReport | None = None, include_input: bool = False, include_metadata: bool = False, include_expected_output: bool = False, include_output: bool = False, include_durations: bool = True, include_total_duration: bool = False, include_removed_cases: bool = False, include_averages: bool = True, input_config: RenderValueConfig | None = None, metadata_config: RenderValueConfig | None = None, output_config: RenderValueConfig | None = None, score_configs: ( dict[str, RenderNumberConfig] | None ) = None, label_configs: ( dict[str, RenderValueConfig] | None ) = None, metric_configs: ( dict[str, RenderNumberConfig] | None ) = None, duration_config: RenderNumberConfig | None = None, ) ``` Print this report to the console, optionally comparing it to a baseline report. If you want more control over the output, use `console_table` instead and pass it to `rich.Console.print`. Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python def print( self, width: int | None = None, baseline: EvaluationReport | None = None, include_input: bool = False, include_metadata: bool = False, include_expected_output: bool = False, include_output: bool = False, include_durations: bool = True, include_total_duration: bool = False, include_removed_cases: bool = False, include_averages: bool = True, input_config: RenderValueConfig | None = None, metadata_config: RenderValueConfig | None = None, output_config: RenderValueConfig | None = None, score_configs: dict[str, RenderNumberConfig] | None = None, label_configs: dict[str, RenderValueConfig] | None = None, metric_configs: dict[str, RenderNumberConfig] | None = None, duration_config: RenderNumberConfig | None = None, ): # pragma: no cover """Print this report to the console, optionally comparing it to a baseline report. If you want more control over the output, use `console_table` instead and pass it to `rich.Console.print`. 
""" table = self.console_table( baseline=baseline, include_input=include_input, include_metadata=include_metadata, include_expected_output=include_expected_output, include_output=include_output, include_durations=include_durations, include_total_duration=include_total_duration, include_removed_cases=include_removed_cases, include_averages=include_averages, input_config=input_config, metadata_config=metadata_config, output_config=output_config, score_configs=score_configs, label_configs=label_configs, metric_configs=metric_configs, duration_config=duration_config, ) Console(width=width).print(table) ``` #### console_table ```python console_table( baseline: EvaluationReport | None = None, include_input: bool = False, include_metadata: bool = False, include_expected_output: bool = False, include_output: bool = False, include_durations: bool = True, include_total_duration: bool = False, include_removed_cases: bool = False, include_averages: bool = True, input_config: RenderValueConfig | None = None, metadata_config: RenderValueConfig | None = None, output_config: RenderValueConfig | None = None, score_configs: ( dict[str, RenderNumberConfig] | None ) = None, label_configs: ( dict[str, RenderValueConfig] | None ) = None, metric_configs: ( dict[str, RenderNumberConfig] | None ) = None, duration_config: RenderNumberConfig | None = None, ) -> Table ``` Return a table containing the data from this report, or the diff between this report and a baseline report. Optionally include input and output details. Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python def console_table( self, baseline: EvaluationReport | None = None, include_input: bool = False, include_metadata: bool = False, include_expected_output: bool = False, include_output: bool = False, include_durations: bool = True, include_total_duration: bool = False, include_removed_cases: bool = False, include_averages: bool = True, input_config: RenderValueConfig | None = None, metadata_config: RenderValueConfig | None = None, output_config: RenderValueConfig | None = None, score_configs: dict[str, RenderNumberConfig] | None = None, label_configs: dict[str, RenderValueConfig] | None = None, metric_configs: dict[str, RenderNumberConfig] | None = None, duration_config: RenderNumberConfig | None = None, ) -> Table: """Return a table containing the data from this report, or the diff between this report and a baseline report. Optionally include input and output details. """ renderer = EvaluationRenderer( include_input=include_input, include_metadata=include_metadata, include_expected_output=include_expected_output, include_output=include_output, include_durations=include_durations, include_total_duration=include_total_duration, include_removed_cases=include_removed_cases, include_averages=include_averages, input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})}, metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})}, output_config=output_config or _DEFAULT_VALUE_CONFIG, score_configs=score_configs or {}, label_configs=label_configs or {}, metric_configs=metric_configs or {}, duration_config=duration_config or _DEFAULT_DURATION_CONFIG, ) if baseline is None: return renderer.build_table(self) else: # pragma: no cover return renderer.build_diff_table(self, baseline) ``` #### __str__ ```python __str__() -> str ``` Return a string representation of the report. 
Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python def __str__(self) -> str: # pragma: lax no cover """Return a string representation of the report.""" table = self.console_table() io_file = StringIO() Console(file=io_file).print(table) return io_file.getvalue() ``` ### RenderValueConfig Bases: `TypedDict` A configuration for rendering a values in an Evaluation report. Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python class RenderValueConfig(TypedDict, total=False): """A configuration for rendering a values in an Evaluation report.""" value_formatter: str | Callable[[Any], str] diff_checker: Callable[[Any, Any], bool] | None diff_formatter: Callable[[Any, Any], str | None] | None diff_style: str ``` ### RenderNumberConfig Bases: `TypedDict` A configuration for rendering a particular score or metric in an Evaluation report. See the implementation of `_RenderNumber` for more clarity on how these parameters affect the rendering. Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python class RenderNumberConfig(TypedDict, total=False): """A configuration for rendering a particular score or metric in an Evaluation report. See the implementation of `_RenderNumber` for more clarity on how these parameters affect the rendering. """ value_formatter: str | Callable[[float | int], str] """The logic to use for formatting values. * If not provided, format as ints if all values are ints, otherwise at least one decimal place and at least four significant figures. * You can also use a custom string format spec, e.g. '{:.3f}' * You can also use a custom function, e.g. lambda x: f'{x:.3f}' """ diff_formatter: str | Callable[[float | int, float | int], str | None] | None """The logic to use for formatting details about the diff. The strings produced by the value_formatter will always be included in the reports, but the diff_formatter is used to produce additional text about the difference between the old and new values, such as the absolute or relative difference. * If not provided, format as ints if all values are ints, otherwise at least one decimal place and at least four significant figures, and will include the percentage change. * You can also use a custom string format spec, e.g. '{:+.3f}' * You can also use a custom function, e.g. lambda x: f'{x:+.3f}'. If this function returns None, no extra diff text will be added. * You can also use None to never generate extra diff text. """ diff_atol: float """The absolute tolerance for considering a difference "significant". A difference is "significant" if `abs(new - old) < self.diff_atol + self.diff_rtol * abs(old)`. If a difference is not significant, it will not have the diff styles applied. Note that we still show both the rendered before and after values in the diff any time they differ, even if the difference is not significant. (If the rendered values are exactly the same, we only show the value once.) If not provided, use 1e-6. """ diff_rtol: float """The relative tolerance for considering a difference "significant". See the description of `diff_atol` for more details about what makes a difference "significant". If not provided, use 0.001 if all values are ints, otherwise 0.05. """ diff_increase_style: str """The style to apply to diffed values that have a significant increase. See the description of `diff_atol` for more details about what makes a difference "significant". If not provided, use green for scores and red for metrics. 
You can also use arbitrary `rich` styles, such as "bold red". """ diff_decrease_style: str """The style to apply to diffed values that have significant decrease. See the description of `diff_atol` for more details about what makes a difference "significant". If not provided, use red for scores and green for metrics. You can also use arbitrary `rich` styles, such as "bold red". """ ``` #### value_formatter ```python value_formatter: str | Callable[[float | int], str] ``` The logic to use for formatting values. - If not provided, format as ints if all values are ints, otherwise at least one decimal place and at least four significant figures. - You can also use a custom string format spec, e.g. '{:.3f}' - You can also use a custom function, e.g. lambda x: f'{x:.3f}' #### diff_formatter ```python diff_formatter: ( str | Callable[[float | int, float | int], str | None] | None ) ``` The logic to use for formatting details about the diff. The strings produced by the value_formatter will always be included in the reports, but the diff_formatter is used to produce additional text about the difference between the old and new values, such as the absolute or relative difference. - If not provided, format as ints if all values are ints, otherwise at least one decimal place and at least four significant figures, and will include the percentage change. - You can also use a custom string format spec, e.g. '{:+.3f}' - You can also use a custom function, e.g. lambda x: f'{x:+.3f}'. If this function returns None, no extra diff text will be added. - You can also use None to never generate extra diff text. #### diff_atol ```python diff_atol: float ``` The absolute tolerance for considering a difference "significant". A difference is "significant" if `abs(new - old) < self.diff_atol + self.diff_rtol * abs(old)`. If a difference is not significant, it will not have the diff styles applied. Note that we still show both the rendered before and after values in the diff any time they differ, even if the difference is not significant. (If the rendered values are exactly the same, we only show the value once.) If not provided, use 1e-6. #### diff_rtol ```python diff_rtol: float ``` The relative tolerance for considering a difference "significant". See the description of `diff_atol` for more details about what makes a difference "significant". If not provided, use 0.001 if all values are ints, otherwise 0.05. #### diff_increase_style ```python diff_increase_style: str ``` The style to apply to diffed values that have a significant increase. See the description of `diff_atol` for more details about what makes a difference "significant". If not provided, use green for scores and red for metrics. You can also use arbitrary `rich` styles, such as "bold red". #### diff_decrease_style ```python diff_decrease_style: str ``` The style to apply to diffed values that have significant decrease. See the description of `diff_atol` for more details about what makes a difference "significant". If not provided, use red for scores and green for metrics. You can also use arbitrary `rich` styles, such as "bold red". ### EvaluationRenderer A class for rendering an EvalReport or the diff between two EvalReports. 
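You typically don't construct `EvaluationRenderer` directly; the score and metric configs described above reach it through `EvaluationReport.print` / `console_table`, e.g. (a sketch with illustrative score and metric names):

```python
# Sketch only: `report` is assumed to exist; 'accuracy' and 'tokens' are illustrative names.
report.print(
    score_configs={
        'accuracy': {
            'value_formatter': '{:.1%}',
            'diff_increase_style': 'bold green',
            'diff_decrease_style': 'bold red',
        },
    },
    metric_configs={'tokens': {'value_formatter': '{:.0f}', 'diff_rtol': 0.1}},
)
```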
Source code in `pydantic_evals/pydantic_evals/reporting/__init__.py` ```python @dataclass class EvaluationRenderer: """A class for rendering an EvalReport or the diff between two EvalReports.""" # Columns to include include_input: bool include_metadata: bool include_expected_output: bool include_output: bool include_durations: bool include_total_duration: bool # Rows to include include_removed_cases: bool include_averages: bool input_config: RenderValueConfig metadata_config: RenderValueConfig output_config: RenderValueConfig score_configs: dict[str, RenderNumberConfig] label_configs: dict[str, RenderValueConfig] metric_configs: dict[str, RenderNumberConfig] duration_config: RenderNumberConfig def include_scores(self, report: EvaluationReport, baseline: EvaluationReport | None = None): return any(case.scores for case in self._all_cases(report, baseline)) def include_labels(self, report: EvaluationReport, baseline: EvaluationReport | None = None): return any(case.labels for case in self._all_cases(report, baseline)) def include_metrics(self, report: EvaluationReport, baseline: EvaluationReport | None = None): return any(case.metrics for case in self._all_cases(report, baseline)) def include_assertions(self, report: EvaluationReport, baseline: EvaluationReport | None = None): return any(case.assertions for case in self._all_cases(report, baseline)) def _all_cases(self, report: EvaluationReport, baseline: EvaluationReport | None) -> list[ReportCase]: if not baseline: return report.cases else: return report.cases + self._baseline_cases_to_include(report, baseline) def _baseline_cases_to_include(self, report: EvaluationReport, baseline: EvaluationReport) -> list[ReportCase]: if self.include_removed_cases: return baseline.cases report_case_names = {case.name for case in report.cases} return [case for case in baseline.cases if case.name in report_case_names] def _get_case_renderer( self, report: EvaluationReport, baseline: EvaluationReport | None = None ) -> ReportCaseRenderer: input_renderer = _ValueRenderer.from_config(self.input_config) metadata_renderer = _ValueRenderer.from_config(self.metadata_config) output_renderer = _ValueRenderer.from_config(self.output_config) score_renderers = self._infer_score_renderers(report, baseline) label_renderers = self._infer_label_renderers(report, baseline) metric_renderers = self._infer_metric_renderers(report, baseline) duration_renderer = _NumberRenderer.infer_from_config( self.duration_config, 'duration', [x.task_duration for x in self._all_cases(report, baseline)] ) return ReportCaseRenderer( include_input=self.include_input, include_metadata=self.include_metadata, include_expected_output=self.include_expected_output, include_output=self.include_output, include_scores=self.include_scores(report, baseline), include_labels=self.include_labels(report, baseline), include_metrics=self.include_metrics(report, baseline), include_assertions=self.include_assertions(report, baseline), include_durations=self.include_durations, include_total_duration=self.include_total_duration, input_renderer=input_renderer, metadata_renderer=metadata_renderer, output_renderer=output_renderer, score_renderers=score_renderers, label_renderers=label_renderers, metric_renderers=metric_renderers, duration_renderer=duration_renderer, ) def build_table(self, report: EvaluationReport) -> Table: case_renderer = self._get_case_renderer(report) table = case_renderer.build_base_table(f'Evaluation Summary: {report.name}') for case in report.cases: 
table.add_row(*case_renderer.build_row(case)) if self.include_averages: # pragma: no branch average = report.averages() table.add_row(*case_renderer.build_aggregate_row(average)) return table def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport) -> Table: report_cases = report.cases baseline_cases = self._baseline_cases_to_include(report, baseline) report_cases_by_id = {case.name: case for case in report_cases} baseline_cases_by_id = {case.name: case for case in baseline_cases} diff_cases: list[tuple[ReportCase, ReportCase]] = [] removed_cases: list[ReportCase] = [] added_cases: list[ReportCase] = [] for case_id in sorted(set(baseline_cases_by_id.keys()) | set(report_cases_by_id.keys())): maybe_baseline_case = baseline_cases_by_id.get(case_id) maybe_report_case = report_cases_by_id.get(case_id) if maybe_baseline_case and maybe_report_case: diff_cases.append((maybe_baseline_case, maybe_report_case)) elif maybe_baseline_case: removed_cases.append(maybe_baseline_case) elif maybe_report_case: added_cases.append(maybe_report_case) else: # pragma: no cover assert False, 'This should be unreachable' case_renderer = self._get_case_renderer(report, baseline) diff_name = baseline.name if baseline.name == report.name else f'{baseline.name} → {report.name}' table = case_renderer.build_base_table(f'Evaluation Diff: {diff_name}') for baseline_case, new_case in diff_cases: table.add_row(*case_renderer.build_diff_row(new_case, baseline_case)) for case in added_cases: row = case_renderer.build_row(case) row[0] = f'[green]+ Added Case[/]\n{row[0]}' table.add_row(*row) for case in removed_cases: row = case_renderer.build_row(case) row[0] = f'[red]- Removed Case[/]\n{row[0]}' table.add_row(*row) if self.include_averages: # pragma: no branch report_average = ReportCaseAggregate.average(report_cases) baseline_average = ReportCaseAggregate.average(baseline_cases) table.add_row(*case_renderer.build_diff_aggregate_row(report_average, baseline_average)) return table def _infer_score_renderers( self, report: EvaluationReport, baseline: EvaluationReport | None ) -> dict[str, _NumberRenderer]: all_cases = self._all_cases(report, baseline) values_by_name: dict[str, list[float | int]] = {} for case in all_cases: for k, score in case.scores.items(): values_by_name.setdefault(k, []).append(score.value) all_renderers: dict[str, _NumberRenderer] = {} for name, values in values_by_name.items(): merged_config = _DEFAULT_NUMBER_CONFIG.copy() merged_config.update(self.score_configs.get(name, {})) all_renderers[name] = _NumberRenderer.infer_from_config(merged_config, 'score', values) return all_renderers def _infer_label_renderers( self, report: EvaluationReport, baseline: EvaluationReport | None ) -> dict[str, _ValueRenderer]: all_cases = self._all_cases(report, baseline) all_names: set[str] = set() for case in all_cases: for k in case.labels: all_names.add(k) all_renderers: dict[str, _ValueRenderer] = {} for name in all_names: merged_config = _DEFAULT_VALUE_CONFIG.copy() merged_config.update(self.label_configs.get(name, {})) all_renderers[name] = _ValueRenderer.from_config(merged_config) return all_renderers def _infer_metric_renderers( self, report: EvaluationReport, baseline: EvaluationReport | None ) -> dict[str, _NumberRenderer]: all_cases = self._all_cases(report, baseline) values_by_name: dict[str, list[float | int]] = {} for case in all_cases: for k, v in case.metrics.items(): values_by_name.setdefault(k, []).append(v) all_renderers: dict[str, _NumberRenderer] = {} for name, values in 
values_by_name.items(): merged_config = _DEFAULT_NUMBER_CONFIG.copy() merged_config.update(self.metric_configs.get(name, {})) all_renderers[name] = _NumberRenderer.infer_from_config(merged_config, 'metric', values) return all_renderers def _infer_duration_renderer( self, report: EvaluationReport, baseline: EvaluationReport | None ) -> _NumberRenderer: # pragma: no cover all_cases = self._all_cases(report, baseline) all_durations = [x.task_duration for x in all_cases] if self.include_total_duration: all_durations += [x.total_duration for x in all_cases] return _NumberRenderer.infer_from_config(self.duration_config, 'duration', all_durations) ``` # `pydantic_graph.exceptions` ### GraphSetupError Bases: `TypeError` Error caused by an incorrectly configured graph. Source code in `pydantic_graph/pydantic_graph/exceptions.py` ```python class GraphSetupError(TypeError): """Error caused by an incorrectly configured graph.""" message: str """Description of the mistake.""" def __init__(self, message: str): self.message = message super().__init__(message) ``` #### message ```python message: str = message ``` Description of the mistake. ### GraphRuntimeError Bases: `RuntimeError` Error caused by an issue during graph execution. Source code in `pydantic_graph/pydantic_graph/exceptions.py` ```python class GraphRuntimeError(RuntimeError): """Error caused by an issue during graph execution.""" message: str """The error message.""" def __init__(self, message: str): self.message = message super().__init__(message) ``` #### message ```python message: str = message ``` The error message. ### GraphNodeStatusError Bases: `GraphRuntimeError` Error caused by trying to run a node that already has status `'running'`, `'success'`, or `'error'`. Source code in `pydantic_graph/pydantic_graph/exceptions.py` ```python class GraphNodeStatusError(GraphRuntimeError): """Error caused by trying to run a node that already has status `'running'`, `'success'`, or `'error'`.""" def __init__(self, actual_status: 'SnapshotStatus'): self.actual_status = actual_status super().__init__(f"Incorrect snapshot status {actual_status!r}, must be 'created' or 'pending'.") @classmethod def check(cls, status: 'SnapshotStatus') -> None: """Check if the status is valid.""" if status not in {'created', 'pending'}: raise cls(status) ``` #### check ```python check(status: SnapshotStatus) -> None ``` Check if the status is valid. Source code in `pydantic_graph/pydantic_graph/exceptions.py` ```python @classmethod def check(cls, status: 'SnapshotStatus') -> None: """Check if the status is valid.""" if status not in {'created', 'pending'}: raise cls(status) ``` # `pydantic_graph` ### Graph Bases: `Generic[StateT, DepsT, RunEndT]` Definition of a graph. In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 42 at the end. 
never_42.py ```py from __future__ import annotations from dataclasses import dataclass from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass class MyState: number: int @dataclass class Increment(BaseNode[MyState]): async def run(self, ctx: GraphRunContext) -> Check42: ctx.state.number += 1 return Check42() @dataclass class Check42(BaseNode[MyState, None, int]): async def run(self, ctx: GraphRunContext) -> Increment | End[int]: if ctx.state.number == 42: return Increment() else: return End(ctx.state.number) never_42_graph = Graph(nodes=(Increment, Check42)) ``` *(This example is complete, it can be run "as is")* See run For an example of running graph, and mermaid_code for an example of generating a mermaid diagram from the graph. Source code in `pydantic_graph/pydantic_graph/graph.py` ````python @dataclass(init=False) class Graph(Generic[StateT, DepsT, RunEndT]): """Definition of a graph. In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 42 at the end. ```py {title="never_42.py" noqa="I001" py="3.10"} from __future__ import annotations from dataclasses import dataclass from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass class MyState: number: int @dataclass class Increment(BaseNode[MyState]): async def run(self, ctx: GraphRunContext) -> Check42: ctx.state.number += 1 return Check42() @dataclass class Check42(BaseNode[MyState, None, int]): async def run(self, ctx: GraphRunContext) -> Increment | End[int]: if ctx.state.number == 42: return Increment() else: return End(ctx.state.number) never_42_graph = Graph(nodes=(Increment, Check42)) ``` _(This example is complete, it can be run "as is")_ See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph, and [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram from the graph. """ name: str | None node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] _state_type: type[StateT] | _utils.Unset = field(repr=False) _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) auto_instrument: bool = field(repr=False) def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], name: str | None = None, state_type: type[StateT] | _utils.Unset = _utils.UNSET, run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, auto_instrument: bool = True, ): """Create a graph from a sequence of nodes. Args: nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same state type. name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. state_type: The type of the state for the graph, this can generally be inferred from `nodes`. run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`. auto_instrument: Whether to create a span for the graph run and the execution of each node's run method. 
""" self.name = name self._state_type = state_type self._run_end_type = run_end_type self.auto_instrument = auto_instrument parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) self.node_defs = {} for node in nodes: self._register_node(node, parent_namespace) self._validate_edges() async def run( self, start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, infer_name: bool = True, ) -> GraphRunResult[StateT, RunEndT]: """Run the graph from a starting node until it ends. Args: start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: A `GraphRunResult` containing information about the run, including its final result. Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: ```py {title="run_never_42.py" noqa="I001" py="3.10" requires="never_42.py"} from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) state = MyState(41) await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) async with self.iter( start_node, state=state, deps=deps, persistence=persistence, infer_name=False ) as graph_run: async for _node in graph_run: pass result = graph_run.result assert result is not None, 'GraphRun should have a result' return result def run_sync( self, start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, infer_name: bool = True, ) -> GraphRunResult[StateT, RunEndT]: """Synchronously run the graph. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Args: start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: The result type from ending the run and the history of the run. """ if infer_name and self.name is None: # pragma: no branch self._infer_name(inspect.currentframe()) return _utils.get_event_loop().run_until_complete( self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False) ) @asynccontextmanager async def iter( self, start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, span: AbstractContextManager[AbstractSpan] | None = None, infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]: """A contextmanager which can be used to iterate over the graph's nodes as they are executed. 
This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as they are executed. This is the API to use if you want to record or interact with the nodes as the graph execution unfolds. The `GraphRun` can also be used to manually drive the graph execution by calling [`GraphRun.next`][pydantic_graph.graph.GraphRun.next]. The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once it has completed. For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun]. Args: start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. span: The span to use for the graph run. If not provided, a new span will be created. infer_name: Whether to infer the graph name from the calling frame. Returns: A GraphRun that can be async iterated over to drive the graph to completion. """ if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) if persistence is None: persistence = SimpleStatePersistence() persistence.set_graph_types(self) with ExitStack() as stack: entered_span: AbstractSpan | None = None if span is None: if self.auto_instrument: entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self)) else: entered_span = stack.enter_context(span) traceparent = None if entered_span is None else get_traceparent(entered_span) yield GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps, traceparent=traceparent, ) @asynccontextmanager async def iter_from_persistence( self, persistence: BaseStatePersistence[StateT, RunEndT], *, deps: DepsT = None, span: AbstractContextManager[AbstractSpan] | None = None, infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]: """A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object. This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter], but instead of passing the node to run, it will restore the node and state from state persistence. Args: persistence: The state persistence interface to use. deps: The dependencies of the graph. span: The span to use for the graph run. If not provided, a new span will be created. infer_name: Whether to infer the graph name from the calling frame. Returns: A GraphRun that can be async iterated over to drive the graph to completion. 
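        As an illustrative sketch (not part of the original docstring), a run can be prepared
        with `initialize` and then resumed here; it assumes the file-based
        `FileStatePersistence` from `pydantic_graph.persistence.file`:

        ```py {title="iter_from_persistence_never_42.py" requires="never_42.py"}
        from pathlib import Path

        from pydantic_graph.persistence.file import FileStatePersistence

        from never_42 import Increment, MyState, never_42_graph


        async def main():
            persistence = FileStatePersistence(Path('never_42_run.json'))
            # record the starting node and state without running anything yet
            await never_42_graph.initialize(Increment(), persistence, state=MyState(1))

            # later (possibly in another process) restore the snapshot and finish the run
            async with never_42_graph.iter_from_persistence(persistence) as graph_run:
                async for node in graph_run:
                    print(node)
        ```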
""" if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) persistence.set_graph_types(self) snapshot = await persistence.load_next() if snapshot is None: raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') snapshot.node.set_snapshot_id(snapshot.id) if self.auto_instrument and span is None: # pragma: no branch span = logfire_api.span('run graph {graph.name}', graph=self) with ExitStack() as stack: entered_span = None if span is None else stack.enter_context(span) traceparent = None if entered_span is None else get_traceparent(entered_span) yield GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=snapshot.node, persistence=persistence, state=snapshot.state, deps=deps, snapshot_id=snapshot.id, traceparent=traceparent, ) async def initialize( self, node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], *, state: StateT = None, infer_name: bool = True, ) -> None: """Initialize a new graph run in persistence without running it. This is useful if you want to set up a graph run to be run later, e.g. via [`iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]. Args: node: The node to run first. persistence: State persistence interface. state: The start state of the graph. infer_name: Whether to infer the graph name from the calling frame. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) persistence.set_graph_types(self) await persistence.snapshot_node(state, node) @deprecated('`next` is deprecated, use `async with graph.iter(...) as run: run.next()` instead') async def next( self, node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], *, state: StateT = None, deps: DepsT = None, infer_name: bool = True, ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]: """Run a node in the graph and return the next node to run. Args: node: The node to run. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. state: The current state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) persistence.set_graph_types(self) run = GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=node, persistence=persistence, state=state, deps=deps, traceparent=None, ) return await run.next(node) def mermaid_code( self, *, start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, title: str | None | typing_extensions.Literal[False] = None, edge_labels: bool = True, notes: bool = True, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, direction: mermaid.StateDiagramDirection | None = None, ) -> str: """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram. This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. Args: start_node: The node or nodes which can start the graph. title: The title of the diagram, use `False` to not include a title. edge_labels: Whether to include edge labels. notes: Whether to include notes on each node. 
highlighted_nodes: Optional node or nodes to highlight. highlight_css: The CSS to use for highlighting nodes. infer_name: Whether to infer the graph name from the calling frame. direction: The direction of flow. Returns: The mermaid code for the graph, which can then be rendered as a diagram. Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: ```py {title="mermaid_never_42.py" py="3.10" requires="never_42.py"} from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) ''' --- title: never_42_graph --- stateDiagram-v2 [*] --> Increment Increment --> Check42 Check42 --> Increment Check42 --> [*] ''' ``` The rendered diagram will look like this: ```mermaid --- title: never_42_graph --- stateDiagram-v2 [*] --> Increment Increment --> Check42 Check42 --> Increment Check42 --> [*] ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if title is None and self.name: title = self.name return mermaid.generate_code( self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css, title=title or None, edge_labels=edge_labels, notes=notes, direction=direction, ) def mermaid_image( self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] ) -> bytes: """Generate a diagram representing the graph as an image. The format and diagram can be customized using `kwargs`, see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. !!! note "Uses external service" This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` is a free service not affiliated with Pydantic. Args: infer_name: Whether to infer the graph name from the calling frame. **kwargs: Additional arguments to pass to `mermaid.request_image`. Returns: The image bytes. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: kwargs['title'] = self.name return mermaid.request_image(self, **kwargs) def mermaid_save( self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] ) -> None: """Generate a diagram representing the graph and save it as an image. The format and diagram can be customized using `kwargs`, see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. !!! note "Uses external service" This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` is a free service not affiliated with Pydantic. Args: path: The path to save the image to. infer_name: Whether to infer the graph name from the calling frame. **kwargs: Additional arguments to pass to `mermaid.save_image`. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) def get_nodes(self) -> Sequence[type[BaseNode[StateT, DepsT, RunEndT]]]: """Get the nodes in the graph.""" return [node_def.node for node_def in self.node_defs.values()] @cached_property def inferred_types(self) -> tuple[type[StateT], type[RunEndT]]: # Get the types of the state and run end from the graph. 
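        # If both types were passed explicitly to `__init__` they are returned as-is.
        # Otherwise each node's original `BaseNode[StateT, DepsT, RunEndT]` base is
        # inspected: the first type argument supplies the state type and the third
        # supplies the run end type (skipping `Never`, used by nodes that cannot end
        # the run). Anything still unset afterwards falls back to `None`.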
if _utils.is_set(self._state_type) and _utils.is_set(self._run_end_type): return self._state_type, self._run_end_type state_type = self._state_type run_end_type = self._run_end_type for node_def in self.node_defs.values(): for base in typing_extensions.get_original_bases(node_def.node): if typing_extensions.get_origin(base) is BaseNode: args = typing_extensions.get_args(base) if not _utils.is_set(state_type) and args: state_type = args[0] if not _utils.is_set(run_end_type) and len(args) == 3: t = args[2] if not typing_objects.is_never(t): run_end_type = t if _utils.is_set(state_type) and _utils.is_set(run_end_type): return state_type, run_end_type # pyright: ignore[reportReturnType] # break the inner (bases) loop break if not _utils.is_set(state_type): # pragma: no branch # state defaults to None, so use that if we can't infer it state_type = None if not _utils.is_set(run_end_type): # this happens if a graph has no return nodes, use None so any downstream errors are clear run_end_type = None return state_type, run_end_type # pyright: ignore[reportReturnType] def _register_node( self, node: type[BaseNode[StateT, DepsT, RunEndT]], parent_namespace: dict[str, Any] | None, ) -> None: node_id = node.get_node_id() if existing_node := self.node_defs.get(node_id): raise exceptions.GraphSetupError( f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' ) else: self.node_defs[node_id] = node.get_node_def(parent_namespace) def _validate_edges(self): known_node_ids = self.node_defs.keys() bad_edges: dict[str, list[str]] = {} for node_id, node_def in self.node_defs.items(): for edge in node_def.next_node_edges.keys(): if edge not in known_node_ids: bad_edges.setdefault(edge, []).append(f'`{node_id}`') if bad_edges: bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] if len(bad_edges_list) == 1: raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') else: b = '\n'.join(f' {be}' for be in bad_edges_list) raise exceptions.GraphSetupError( f'Nodes are referenced in the graph but not included in the graph:\n{b}' ) def _infer_name(self, function_frame: types.FrameType | None) -> None: """Infer the agent name from the call frame. Usage should be `self._infer_name(inspect.currentframe())`. Copied from `Agent`. """ assert self.name is None, 'Name already set' if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch for name, item in parent_frame.f_locals.items(): if item is self: self.name = name return if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch # if we couldn't find the agent in locals and globals are a different dict, try globals for name, item in parent_frame.f_globals.items(): # pragma: no branch if item is self: self.name = name return ```` #### __init__ ```python __init__( *, nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], name: str | None = None, state_type: type[StateT] | Unset = UNSET, run_end_type: type[RunEndT] | Unset = UNSET, auto_instrument: bool = True ) ``` Create a graph from a sequence of nodes. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `nodes` | `Sequence[type[BaseNode[StateT, DepsT, RunEndT]]]` | The nodes which make up the graph, nodes need to be unique and all be generic in the same state type. | *required* | | `name` | `str | None` | Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. 
| `None` | | `state_type` | `type[StateT] | Unset` | The type of the state for the graph, this can generally be inferred from nodes. | `UNSET` | | `run_end_type` | `type[RunEndT] | Unset` | The type of the result of running the graph, this can generally be inferred from nodes. | `UNSET` | | `auto_instrument` | `bool` | Whether to create a span for the graph run and the execution of each node's run method. | `True` | Source code in `pydantic_graph/pydantic_graph/graph.py` ```python def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], name: str | None = None, state_type: type[StateT] | _utils.Unset = _utils.UNSET, run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, auto_instrument: bool = True, ): """Create a graph from a sequence of nodes. Args: nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same state type. name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. state_type: The type of the state for the graph, this can generally be inferred from `nodes`. run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`. auto_instrument: Whether to create a span for the graph run and the execution of each node's run method. """ self.name = name self._state_type = state_type self._run_end_type = run_end_type self.auto_instrument = auto_instrument parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) self.node_defs = {} for node in nodes: self._register_node(node, parent_namespace) self._validate_edges() ``` #### run ```python run( start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: ( BaseStatePersistence[StateT, RunEndT] | None ) = None, infer_name: bool = True ) -> GraphRunResult[StateT, RunEndT] ``` Run the graph from a starting node until it ends. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `start_node` | `BaseNode[StateT, DepsT, RunEndT]` | the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. | *required* | | `state` | `StateT` | The initial state of the graph. | `None` | | `deps` | `DepsT` | The dependencies of the graph. | `None` | | `persistence` | `BaseStatePersistence[StateT, RunEndT] | None` | State persistence interface, defaults to SimpleStatePersistence if None. | `None` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | Returns: | Type | Description | | --- | --- | | `GraphRunResult[StateT, RunEndT]` | A GraphRunResult containing information about the run, including its final result. | Here's an example of running the graph from above: run_never_42.py ```py from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) state = MyState(41) await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) ``` Source code in `pydantic_graph/pydantic_graph/graph.py` ````python async def run( self, start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, infer_name: bool = True, ) -> GraphRunResult[StateT, RunEndT]: """Run the graph from a starting node until it ends. 
Args: start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: A `GraphRunResult` containing information about the run, including its final result. Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: ```py {title="run_never_42.py" noqa="I001" py="3.10" requires="never_42.py"} from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) state = MyState(41) await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) async with self.iter( start_node, state=state, deps=deps, persistence=persistence, infer_name=False ) as graph_run: async for _node in graph_run: pass result = graph_run.result assert result is not None, 'GraphRun should have a result' return result ```` #### run_sync ```python run_sync( start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: ( BaseStatePersistence[StateT, RunEndT] | None ) = None, infer_name: bool = True ) -> GraphRunResult[StateT, RunEndT] ``` Synchronously run the graph. This is a convenience method that wraps self.run with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `start_node` | `BaseNode[StateT, DepsT, RunEndT]` | the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. | *required* | | `state` | `StateT` | The initial state of the graph. | `None` | | `deps` | `DepsT` | The dependencies of the graph. | `None` | | `persistence` | `BaseStatePersistence[StateT, RunEndT] | None` | State persistence interface, defaults to SimpleStatePersistence if None. | `None` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | Returns: | Type | Description | | --- | --- | | `GraphRunResult[StateT, RunEndT]` | The result type from ending the run and the history of the run. | Source code in `pydantic_graph/pydantic_graph/graph.py` ```python def run_sync( self, start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, infer_name: bool = True, ) -> GraphRunResult[StateT, RunEndT]: """Synchronously run the graph. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Args: start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. 
Returns: The result type from ending the run and the history of the run. """ if infer_name and self.name is None: # pragma: no branch self._infer_name(inspect.currentframe()) return _utils.get_event_loop().run_until_complete( self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False) ) ``` #### iter ```python iter( start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: ( BaseStatePersistence[StateT, RunEndT] | None ) = None, span: ( AbstractContextManager[AbstractSpan] | None ) = None, infer_name: bool = True ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]] ``` A contextmanager which can be used to iterate over the graph's nodes as they are executed. This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as they are executed. This is the API to use if you want to record or interact with the nodes as the graph execution unfolds. The `GraphRun` can also be used to manually drive the graph execution by calling GraphRun.next. The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once it has completed. For more details, see the API documentation of GraphRun. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `start_node` | `BaseNode[StateT, DepsT, RunEndT]` | the first node to run. Since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. | *required* | | `state` | `StateT` | The initial state of the graph. | `None` | | `deps` | `DepsT` | The dependencies of the graph. | `None` | | `persistence` | `BaseStatePersistence[StateT, RunEndT] | None` | State persistence interface, defaults to SimpleStatePersistence if None. | `None` | | `span` | `AbstractContextManager[AbstractSpan] | None` | The span to use for the graph run. If not provided, a new span will be created. | `None` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | Returns: A GraphRun that can be async iterated over to drive the graph to completion. Source code in `pydantic_graph/pydantic_graph/graph.py` ```python @asynccontextmanager async def iter( self, start_node: BaseNode[StateT, DepsT, RunEndT], *, state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, RunEndT] | None = None, span: AbstractContextManager[AbstractSpan] | None = None, infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]: """A contextmanager which can be used to iterate over the graph's nodes as they are executed. This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as they are executed. This is the API to use if you want to record or interact with the nodes as the graph execution unfolds. The `GraphRun` can also be used to manually drive the graph execution by calling [`GraphRun.next`][pydantic_graph.graph.GraphRun.next]. The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once it has completed. For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun]. Args: start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. 
persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. span: The span to use for the graph run. If not provided, a new span will be created. infer_name: Whether to infer the graph name from the calling frame. Returns: A GraphRun that can be async iterated over to drive the graph to completion. """ if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) if persistence is None: persistence = SimpleStatePersistence() persistence.set_graph_types(self) with ExitStack() as stack: entered_span: AbstractSpan | None = None if span is None: if self.auto_instrument: entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self)) else: entered_span = stack.enter_context(span) traceparent = None if entered_span is None else get_traceparent(entered_span) yield GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps, traceparent=traceparent, ) ``` #### iter_from_persistence ```python iter_from_persistence( persistence: BaseStatePersistence[StateT, RunEndT], *, deps: DepsT = None, span: ( AbstractContextManager[AbstractSpan] | None ) = None, infer_name: bool = True ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]] ``` A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object. This method has similar functionality to iter, but instead of passing the node to run, it will restore the node and state from state persistence. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `persistence` | `BaseStatePersistence[StateT, RunEndT]` | The state persistence interface to use. | *required* | | `deps` | `DepsT` | The dependencies of the graph. | `None` | | `span` | `AbstractContextManager[AbstractSpan] | None` | The span to use for the graph run. If not provided, a new span will be created. | `None` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | Returns: A GraphRun that can be async iterated over to drive the graph to completion. Source code in `pydantic_graph/pydantic_graph/graph.py` ```python @asynccontextmanager async def iter_from_persistence( self, persistence: BaseStatePersistence[StateT, RunEndT], *, deps: DepsT = None, span: AbstractContextManager[AbstractSpan] | None = None, infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]: """A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object. This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter], but instead of passing the node to run, it will restore the node and state from state persistence. Args: persistence: The state persistence interface to use. deps: The dependencies of the graph. span: The span to use for the graph run. If not provided, a new span will be created. infer_name: Whether to infer the graph name from the calling frame. Returns: A GraphRun that can be async iterated over to drive the graph to completion. 
""" if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) persistence.set_graph_types(self) snapshot = await persistence.load_next() if snapshot is None: raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') snapshot.node.set_snapshot_id(snapshot.id) if self.auto_instrument and span is None: # pragma: no branch span = logfire_api.span('run graph {graph.name}', graph=self) with ExitStack() as stack: entered_span = None if span is None else stack.enter_context(span) traceparent = None if entered_span is None else get_traceparent(entered_span) yield GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=snapshot.node, persistence=persistence, state=snapshot.state, deps=deps, snapshot_id=snapshot.id, traceparent=traceparent, ) ``` #### initialize ```python initialize( node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], *, state: StateT = None, infer_name: bool = True ) -> None ``` Initialize a new graph run in persistence without running it. This is useful if you want to set up a graph run to be run later, e.g. via iter_from_persistence. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `node` | `BaseNode[StateT, DepsT, RunEndT]` | The node to run first. | *required* | | `persistence` | `BaseStatePersistence[StateT, RunEndT]` | State persistence interface. | *required* | | `state` | `StateT` | The start state of the graph. | `None` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | Source code in `pydantic_graph/pydantic_graph/graph.py` ```python async def initialize( self, node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], *, state: StateT = None, infer_name: bool = True, ) -> None: """Initialize a new graph run in persistence without running it. This is useful if you want to set up a graph run to be run later, e.g. via [`iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]. Args: node: The node to run first. persistence: State persistence interface. state: The start state of the graph. infer_name: Whether to infer the graph name from the calling frame. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) persistence.set_graph_types(self) await persistence.snapshot_node(state, node) ``` #### next ```python next( node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], *, state: StateT = None, deps: DepsT = None, infer_name: bool = True ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT] ``` Run a node in the graph and return the next node to run. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `node` | `BaseNode[StateT, DepsT, RunEndT]` | The node to run. | *required* | | `persistence` | `BaseStatePersistence[StateT, RunEndT]` | State persistence interface, defaults to SimpleStatePersistence if None. | *required* | | `state` | `StateT` | The current state of the graph. | `None` | | `deps` | `DepsT` | The dependencies of the graph. | `None` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | Returns: | Type | Description | | --- | --- | | `BaseNode[StateT, DepsT, Any] | End[RunEndT]` | The next node to run or End if the graph has finished. 
| Source code in `pydantic_graph/pydantic_graph/graph.py` ```python @deprecated('`next` is deprecated, use `async with graph.iter(...) as run: run.next()` instead') async def next( self, node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], *, state: StateT = None, deps: DepsT = None, infer_name: bool = True, ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]: """Run a node in the graph and return the next node to run. Args: node: The node to run. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. state: The current state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) persistence.set_graph_types(self) run = GraphRun[StateT, DepsT, RunEndT]( graph=self, start_node=node, persistence=persistence, state=state, deps=deps, traceparent=None, ) return await run.next(node) ``` #### mermaid_code ```python mermaid_code( *, start_node: ( Sequence[NodeIdent] | NodeIdent | None ) = None, title: str | None | Literal[False] = None, edge_labels: bool = True, notes: bool = True, highlighted_nodes: ( Sequence[NodeIdent] | NodeIdent | None ) = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, direction: StateDiagramDirection | None = None ) -> str ``` Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram. This method calls pydantic_graph.mermaid.generate_code. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `start_node` | `Sequence[NodeIdent] | NodeIdent | None` | The node or nodes which can start the graph. | `None` | | `title` | `str | None | Literal[False]` | The title of the diagram, use False to not include a title. | `None` | | `edge_labels` | `bool` | Whether to include edge labels. | `True` | | `notes` | `bool` | Whether to include notes on each node. | `True` | | `highlighted_nodes` | `Sequence[NodeIdent] | NodeIdent | None` | Optional node or nodes to highlight. | `None` | | `highlight_css` | `str` | The CSS to use for highlighting nodes. | `DEFAULT_HIGHLIGHT_CSS` | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | | `direction` | `StateDiagramDirection | None` | The direction of flow. | `None` | Returns: | Type | Description | | --- | --- | | `str` | The mermaid code for the graph, which can then be rendered as a diagram. 
| Here's an example of generating a diagram for the graph from above: mermaid_never_42.py ```py from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) ''' --- title: never_42_graph --- stateDiagram-v2 [*] --> Increment Increment --> Check42 Check42 --> Increment Check42 --> [*] ''' ``` The rendered diagram will look like this: ``` --- title: never_42_graph --- stateDiagram-v2 [*] --> Increment Increment --> Check42 Check42 --> Increment Check42 --> [*] ``` Source code in `pydantic_graph/pydantic_graph/graph.py` ````python def mermaid_code( self, *, start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, title: str | None | typing_extensions.Literal[False] = None, edge_labels: bool = True, notes: bool = True, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, direction: mermaid.StateDiagramDirection | None = None, ) -> str: """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram. This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. Args: start_node: The node or nodes which can start the graph. title: The title of the diagram, use `False` to not include a title. edge_labels: Whether to include edge labels. notes: Whether to include notes on each node. highlighted_nodes: Optional node or nodes to highlight. highlight_css: The CSS to use for highlighting nodes. infer_name: Whether to infer the graph name from the calling frame. direction: The direction of flow. Returns: The mermaid code for the graph, which can then be rendered as a diagram. Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: ```py {title="mermaid_never_42.py" py="3.10" requires="never_42.py"} from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) ''' --- title: never_42_graph --- stateDiagram-v2 [*] --> Increment Increment --> Check42 Check42 --> Increment Check42 --> [*] ''' ``` The rendered diagram will look like this: ```mermaid --- title: never_42_graph --- stateDiagram-v2 [*] --> Increment Increment --> Check42 Check42 --> Increment Check42 --> [*] ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if title is None and self.name: title = self.name return mermaid.generate_code( self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css, title=title or None, edge_labels=edge_labels, notes=notes, direction=direction, ) ```` #### mermaid_image ```python mermaid_image( infer_name: bool = True, **kwargs: Unpack[MermaidConfig] ) -> bytes ``` Generate a diagram representing the graph as an image. The format and diagram can be customized using `kwargs`, see pydantic_graph.mermaid.MermaidConfig. Uses external service This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` is a free service not affiliated with Pydantic. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | | `**kwargs` | `Unpack[MermaidConfig]` | Additional arguments to pass to mermaid.request_image. | `{}` | Returns: | Type | Description | | --- | --- | | `bytes` | The image bytes. 
| Source code in `pydantic_graph/pydantic_graph/graph.py` ```python def mermaid_image( self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] ) -> bytes: """Generate a diagram representing the graph as an image. The format and diagram can be customized using `kwargs`, see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. !!! note "Uses external service" This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` is a free service not affiliated with Pydantic. Args: infer_name: Whether to infer the graph name from the calling frame. **kwargs: Additional arguments to pass to `mermaid.request_image`. Returns: The image bytes. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: kwargs['title'] = self.name return mermaid.request_image(self, **kwargs) ``` #### mermaid_save ```python mermaid_save( path: Path | str, /, *, infer_name: bool = True, **kwargs: Unpack[MermaidConfig], ) -> None ``` Generate a diagram representing the graph and save it as an image. The format and diagram can be customized using `kwargs`, see pydantic_graph.mermaid.MermaidConfig. Uses external service This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` is a free service not affiliated with Pydantic. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `path` | `Path | str` | The path to save the image to. | *required* | | `infer_name` | `bool` | Whether to infer the graph name from the calling frame. | `True` | | `**kwargs` | `Unpack[MermaidConfig]` | Additional arguments to pass to mermaid.save_image. | `{}` | Source code in `pydantic_graph/pydantic_graph/graph.py` ```python def mermaid_save( self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] ) -> None: """Generate a diagram representing the graph and save it as an image. The format and diagram can be customized using `kwargs`, see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. !!! note "Uses external service" This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` is a free service not affiliated with Pydantic. Args: path: The path to save the image to. infer_name: Whether to infer the graph name from the calling frame. **kwargs: Additional arguments to pass to `mermaid.save_image`. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) ``` #### get_nodes ```python get_nodes() -> ( Sequence[type[BaseNode[StateT, DepsT, RunEndT]]] ) ``` Get the nodes in the graph. Source code in `pydantic_graph/pydantic_graph/graph.py` ```python def get_nodes(self) -> Sequence[type[BaseNode[StateT, DepsT, RunEndT]]]: """Get the nodes in the graph.""" return [node_def.node for node_def in self.node_defs.values()] ``` ### GraphRun Bases: `Generic[StateT, DepsT, RunEndT]` A stateful, async-iterable run of a Graph. You typically get a `GraphRun` instance from calling `async with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`. 
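Once the run has finished, `GraphRun.result` holds a `GraphRunResult` with the final `output` and `state`; while the run is still in progress it is `None`. Here's a minimal sketch of reading it (an illustrative addition reusing the `never_42.py` graph from above, with a hypothetical file name):

read_result_never_42.py
```py
from never_42 import Increment, MyState, never_42_graph


async def main():
    async with never_42_graph.iter(Increment(), state=MyState(1)) as graph_run:
        async for _node in graph_run:
            pass
        assert graph_run.result is not None
        print(graph_run.result.output)
        #> 2
        print(graph_run.result.state)
        #> MyState(number=2)
```

As with the other async examples, add `asyncio.run(main())` to actually run `main`.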
Here's an example of iterating over the graph from above: iter_never_42.py ```py from copy import deepcopy from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) async with never_42_graph.iter(Increment(), state=state) as graph_run: node_states = [(graph_run.next_node, deepcopy(graph_run.state))] async for node in graph_run: node_states.append((node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=1)), (Increment(), MyState(number=1)), (Check42(), MyState(number=2)), (End(data=2), MyState(number=2)), ] ''' state = MyState(41) async with never_42_graph.iter(Increment(), state=state) as graph_run: node_states = [(graph_run.next_node, deepcopy(graph_run.state))] async for node in graph_run: node_states.append((node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=41)), (Increment(), MyState(number=41)), (Check42(), MyState(number=42)), (Increment(), MyState(number=42)), (Check42(), MyState(number=43)), (End(data=43), MyState(number=43)), ] ''' ``` See the GraphRun.next documentation for an example of how to manually drive the graph run. Source code in `pydantic_graph/pydantic_graph/graph.py` ````python class GraphRun(Generic[StateT, DepsT, RunEndT]): """A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph]. You typically get a `GraphRun` instance from calling `async with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`. Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]: ```py {title="iter_never_42.py" noqa="I001" py="3.10" requires="never_42.py"} from copy import deepcopy from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) async with never_42_graph.iter(Increment(), state=state) as graph_run: node_states = [(graph_run.next_node, deepcopy(graph_run.state))] async for node in graph_run: node_states.append((node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=1)), (Increment(), MyState(number=1)), (Check42(), MyState(number=2)), (End(data=2), MyState(number=2)), ] ''' state = MyState(41) async with never_42_graph.iter(Increment(), state=state) as graph_run: node_states = [(graph_run.next_node, deepcopy(graph_run.state))] async for node in graph_run: node_states.append((node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=41)), (Increment(), MyState(number=41)), (Check42(), MyState(number=42)), (Increment(), MyState(number=42)), (Check42(), MyState(number=43)), (End(data=43), MyState(number=43)), ] ''' ``` See the [`GraphRun.next` documentation][pydantic_graph.graph.GraphRun.next] for an example of how to manually drive the graph run. """ def __init__( self, *, graph: Graph[StateT, DepsT, RunEndT], start_node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], state: StateT, deps: DepsT, traceparent: str | None, snapshot_id: str | None = None, ): """Create a new run for a given graph, starting at the specified node. Typically, you'll use [`Graph.iter`][pydantic_graph.graph.Graph.iter] rather than calling this directly. Args: graph: The [`Graph`][pydantic_graph.graph.Graph] to run. start_node: The node where execution will begin. persistence: State persistence interface. 
state: A shared state object or primitive (like a counter, dataclass, etc.) that is available to all nodes via `ctx.state`. deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, configuration, or logging clients. traceparent: The traceparent for the span used for the graph run. snapshot_id: The ID of the snapshot the node came from. """ self.graph = graph self.persistence = persistence self._snapshot_id: str | None = snapshot_id self.state = state self.deps = deps self.__traceparent = traceparent self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node self._is_started: bool = False @overload def _traceparent(self, *, required: typing_extensions.Literal[False]) -> str | None: ... @overload def _traceparent(self) -> str: ... def _traceparent(self, *, required: bool = True) -> str | None: if self.__traceparent is None and required: # pragma: no cover raise exceptions.GraphRuntimeError('No span was created for this graph run') return self.__traceparent @property def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """The next node that will be run in the graph. This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ return self._next_node @property def result(self) -> GraphRunResult[StateT, RunEndT] | None: """The final result of the graph run if the run is completed, otherwise `None`.""" if not isinstance(self._next_node, End): return None # The GraphRun has not finished running return GraphRunResult[StateT, RunEndT]( self._next_node.data, state=self.state, persistence=self.persistence, traceparent=self._traceparent(required=False), ) async def next( self, node: BaseNode[StateT, DepsT, RunEndT] | None = None ) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Manually drive the graph run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node. Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]: ```py {title="next_never_42.py" noqa="I001" py="3.10" requires="never_42.py"} from copy import deepcopy from pydantic_graph import End from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(48) async with never_42_graph.iter(Increment(), state=state) as graph_run: next_node = graph_run.next_node # start with the first node node_states = [(next_node, deepcopy(graph_run.state))] while not isinstance(next_node, End): if graph_run.state.number == 50: graph_run.state.number = 42 next_node = await graph_run.next(next_node) node_states.append((next_node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=48)), (Check42(), MyState(number=49)), (End(data=49), MyState(number=49)), ] ''' ``` Args: node: The node to run next in the graph. If not specified, uses `self.next_node`, which is initialized to the `start_node` of the run and updated each time a new node is returned. Returns: The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if the run has completed. """ if node is None: # This cast is necessary because self._next_node could be an `End`. You'll get a runtime error if that's # the case, but if it is, the only way to get there would be to have tried calling next manually after # the run finished. 
Either way, maybe it would be better to not do this cast... node = cast(BaseNode[StateT, DepsT, RunEndT], self._next_node) node_snapshot_id = node.get_snapshot_id() else: node_snapshot_id = node.get_snapshot_id() if node_snapshot_id != self._snapshot_id: await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node) self._snapshot_id = node_snapshot_id if not isinstance(node, BaseNode): # While technically this is not compatible with the documented method signature, it's an easy mistake to # make, and we should eagerly provide a more helpful error message than you'd get otherwise. raise TypeError(f'`next` must be called with a `BaseNode` instance, got {node!r}.') node_id = node.get_node_id() if node_id not in self.graph.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') with ExitStack() as stack: if self.graph.auto_instrument: stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) async with self.persistence.record_run(node_snapshot_id): ctx = GraphRunContext(self.state, self.deps) self._next_node = await node.run(ctx) if isinstance(self._next_node, End): self._snapshot_id = self._next_node.get_snapshot_id() await self.persistence.snapshot_end(self.state, self._next_node) elif isinstance(self._next_node, BaseNode): self._snapshot_id = self._next_node.get_snapshot_id() await self.persistence.snapshot_node(self.state, self._next_node) else: raise exceptions.GraphRuntimeError( f'Invalid node return type: `{type(self._next_node).__name__}`. Expected `BaseNode` or `End`.' ) return self._next_node def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]: return self async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" if not self._is_started: self._is_started = True return self._next_node if isinstance(self._next_node, End): raise StopAsyncIteration return await self.next(self._next_node) def __repr__(self) -> str: return f'' ```` #### __init__ ```python __init__( *, graph: Graph[StateT, DepsT, RunEndT], start_node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], state: StateT, deps: DepsT, traceparent: str | None, snapshot_id: str | None = None ) ``` Create a new run for a given graph, starting at the specified node. Typically, you'll use Graph.iter rather than calling this directly. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `graph` | `Graph[StateT, DepsT, RunEndT]` | The Graph to run. | *required* | | `start_node` | `BaseNode[StateT, DepsT, RunEndT]` | The node where execution will begin. | *required* | | `persistence` | `BaseStatePersistence[StateT, RunEndT]` | State persistence interface. | *required* | | `state` | `StateT` | A shared state object or primitive (like a counter, dataclass, etc.) that is available to all nodes via ctx.state. | *required* | | `deps` | `DepsT` | Optional dependencies that each node can access via ctx.deps, e.g. database connections, configuration, or logging clients. | *required* | | `traceparent` | `str | None` | The traceparent for the span used for the graph run. | *required* | | `snapshot_id` | `str | None` | The ID of the snapshot the node came from. 
| `None` | Source code in `pydantic_graph/pydantic_graph/graph.py` ```python def __init__( self, *, graph: Graph[StateT, DepsT, RunEndT], start_node: BaseNode[StateT, DepsT, RunEndT], persistence: BaseStatePersistence[StateT, RunEndT], state: StateT, deps: DepsT, traceparent: str | None, snapshot_id: str | None = None, ): """Create a new run for a given graph, starting at the specified node. Typically, you'll use [`Graph.iter`][pydantic_graph.graph.Graph.iter] rather than calling this directly. Args: graph: The [`Graph`][pydantic_graph.graph.Graph] to run. start_node: The node where execution will begin. persistence: State persistence interface. state: A shared state object or primitive (like a counter, dataclass, etc.) that is available to all nodes via `ctx.state`. deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, configuration, or logging clients. traceparent: The traceparent for the span used for the graph run. snapshot_id: The ID of the snapshot the node came from. """ self.graph = graph self.persistence = persistence self._snapshot_id: str | None = snapshot_id self.state = state self.deps = deps self.__traceparent = traceparent self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node self._is_started: bool = False ``` #### next_node ```python next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] ``` The next node that will be run in the graph. This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. #### result ```python result: GraphRunResult[StateT, RunEndT] | None ``` The final result of the graph run if the run is completed, otherwise `None`. #### next ```python next( node: BaseNode[StateT, DepsT, RunEndT] | None = None, ) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] ``` Manually drive the graph run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes under dynamic conditions. The graph run should stop when you return an End node. Here's an example of using `next` to drive the graph from above: next_never_42.py ```py from copy import deepcopy from pydantic_graph import End from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(48) async with never_42_graph.iter(Increment(), state=state) as graph_run: next_node = graph_run.next_node # start with the first node node_states = [(next_node, deepcopy(graph_run.state))] while not isinstance(next_node, End): if graph_run.state.number == 50: graph_run.state.number = 42 next_node = await graph_run.next(next_node) node_states.append((next_node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=48)), (Check42(), MyState(number=49)), (End(data=49), MyState(number=49)), ] ''' ``` Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `node` | `BaseNode[StateT, DepsT, RunEndT] | None` | The node to run next in the graph. If not specified, uses self.next_node, which is initialized to the start_node of the run and updated each time a new node is returned. | `None` | Returns: | Type | Description | | --- | --- | | `BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]` | The next node returned by the graph logic, or an End node if | | `BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]` | the run has completed. 
| Source code in `pydantic_graph/pydantic_graph/graph.py` ````python async def next( self, node: BaseNode[StateT, DepsT, RunEndT] | None = None ) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Manually drive the graph run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node. Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]: ```py {title="next_never_42.py" noqa="I001" py="3.10" requires="never_42.py"} from copy import deepcopy from pydantic_graph import End from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(48) async with never_42_graph.iter(Increment(), state=state) as graph_run: next_node = graph_run.next_node # start with the first node node_states = [(next_node, deepcopy(graph_run.state))] while not isinstance(next_node, End): if graph_run.state.number == 50: graph_run.state.number = 42 next_node = await graph_run.next(next_node) node_states.append((next_node, deepcopy(graph_run.state))) print(node_states) ''' [ (Increment(), MyState(number=48)), (Check42(), MyState(number=49)), (End(data=49), MyState(number=49)), ] ''' ``` Args: node: The node to run next in the graph. If not specified, uses `self.next_node`, which is initialized to the `start_node` of the run and updated each time a new node is returned. Returns: The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if the run has completed. """ if node is None: # This cast is necessary because self._next_node could be an `End`. You'll get a runtime error if that's # the case, but if it is, the only way to get there would be to have tried calling next manually after # the run finished. Either way, maybe it would be better to not do this cast... node = cast(BaseNode[StateT, DepsT, RunEndT], self._next_node) node_snapshot_id = node.get_snapshot_id() else: node_snapshot_id = node.get_snapshot_id() if node_snapshot_id != self._snapshot_id: await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node) self._snapshot_id = node_snapshot_id if not isinstance(node, BaseNode): # While technically this is not compatible with the documented method signature, it's an easy mistake to # make, and we should eagerly provide a more helpful error message than you'd get otherwise. raise TypeError(f'`next` must be called with a `BaseNode` instance, got {node!r}.') node_id = node.get_node_id() if node_id not in self.graph.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') with ExitStack() as stack: if self.graph.auto_instrument: stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) async with self.persistence.record_run(node_snapshot_id): ctx = GraphRunContext(self.state, self.deps) self._next_node = await node.run(ctx) if isinstance(self._next_node, End): self._snapshot_id = self._next_node.get_snapshot_id() await self.persistence.snapshot_end(self.state, self._next_node) elif isinstance(self._next_node, BaseNode): self._snapshot_id = self._next_node.get_snapshot_id() await self.persistence.snapshot_node(self.state, self._next_node) else: raise exceptions.GraphRuntimeError( f'Invalid node return type: `{type(self._next_node).__name__}`. Expected `BaseNode` or `End`.' 
) return self._next_node ```` #### __anext__ ```python __anext__() -> ( BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] ) ``` Use the last returned node as the input to `Graph.next`. Source code in `pydantic_graph/pydantic_graph/graph.py` ```python async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" if not self._is_started: self._is_started = True return self._next_node if isinstance(self._next_node, End): raise StopAsyncIteration return await self.next(self._next_node) ``` ### GraphRunResult Bases: `Generic[StateT, RunEndT]` The final result of running a graph. Source code in `pydantic_graph/pydantic_graph/graph.py` ```python @dataclass(init=False) class GraphRunResult(Generic[StateT, RunEndT]): """The final result of running a graph.""" output: RunEndT state: StateT persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False) def __init__( self, output: RunEndT, state: StateT, persistence: BaseStatePersistence[StateT, RunEndT], traceparent: str | None = None, ): self.output = output self.state = state self.persistence = persistence self.__traceparent = traceparent @overload def _traceparent(self, *, required: typing_extensions.Literal[False]) -> str | None: ... @overload def _traceparent(self) -> str: ... def _traceparent(self, *, required: bool = True) -> str | None: # pragma: no cover if self.__traceparent is None and required: raise exceptions.GraphRuntimeError('No span was created for this graph run.') return self.__traceparent ``` # `pydantic_graph.mermaid` ### DEFAULT_HIGHLIGHT_CSS ```python DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' ``` The default CSS to use for highlighting nodes. ### StateDiagramDirection ```python StateDiagramDirection = Literal['TB', 'LR', 'RL', 'BT'] ``` Used to specify the direction of the state diagram generated by mermaid. - `'TB'`: Top to bottom, this is the default for mermaid charts. - `'LR'`: Left to right - `'RL'`: Right to left - `'BT'`: Bottom to top ### generate_code ```python generate_code( graph: Graph[Any, Any, Any], /, *, start_node: ( Sequence[NodeIdent] | NodeIdent | None ) = None, highlighted_nodes: ( Sequence[NodeIdent] | NodeIdent | None ) = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, title: str | None = None, edge_labels: bool = True, notes: bool = True, direction: StateDiagramDirection | None, ) -> str ``` Generate [Mermaid state diagram](https://mermaid.js.org/syntax/stateDiagram.html) code for a graph. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `graph` | `Graph[Any, Any, Any]` | The graph to generate the image for. | *required* | | `start_node` | `Sequence[NodeIdent] | NodeIdent | None` | Identifiers of nodes that start the graph. | `None` | | `highlighted_nodes` | `Sequence[NodeIdent] | NodeIdent | None` | Identifiers of nodes to highlight. | `None` | | `highlight_css` | `str` | CSS to use for highlighting nodes. | `DEFAULT_HIGHLIGHT_CSS` | | `title` | `str | None` | The title of the diagram. | `None` | | `edge_labels` | `bool` | Whether to include edge labels in the diagram. | `True` | | `notes` | `bool` | Whether to include notes in the diagram. | `True` | | `direction` | `StateDiagramDirection | None` | The direction of flow. | *required* | Returns: | Type | Description | | --- | --- | | `str` | The Mermaid code for the graph. 
| Source code in `pydantic_graph/pydantic_graph/mermaid.py` ```python def generate_code( # noqa: C901 graph: Graph[Any, Any, Any], /, *, start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, title: str | None = None, edge_labels: bool = True, notes: bool = True, direction: StateDiagramDirection | None, ) -> str: """Generate [Mermaid state diagram](https://mermaid.js.org/syntax/stateDiagram.html) code for a graph. Args: graph: The graph to generate the image for. start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. title: The title of the diagram. edge_labels: Whether to include edge labels in the diagram. notes: Whether to include notes in the diagram. direction: The direction of flow. Returns: The Mermaid code for the graph. """ start_node_ids = set(_node_ids(start_node or ())) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') lines: list[str] = [] if title: lines = ['---', f'title: {title}', '---'] lines.append('stateDiagram-v2') if direction is not None: lines.append(f' direction {direction}') for node_id, node_def in graph.node_defs.items(): # we use round brackets (rounded box) for nodes other than the start and end if node_id in start_node_ids: lines.append(f' [*] --> {node_id}') if node_def.returns_base_node: for next_node_id in graph.node_defs: lines.append(f' {node_id} --> {next_node_id}') else: for next_node_id, edge in node_def.next_node_edges.items(): line = f' {node_id} --> {next_node_id}' if edge_labels and edge.label: line += f': {edge.label}' lines.append(line) if end_edge := node_def.end_edge: line = f' {node_id} --> [*]' if edge_labels and end_edge.label: line += f': {end_edge.label}' lines.append(line) if notes and node_def.note: lines.append(f' note right of {node_id}') # mermaid doesn't like multiple paragraphs in a note, and shows if so clean_docs = re.sub('\n{2,}', '\n', node_def.note) lines.append(indent(clean_docs, ' ')) lines.append(' end note') if highlighted_nodes: lines.append('') lines.append(f'classDef highlighted {highlight_css}') for node_id in _node_ids(highlighted_nodes): if node_id not in graph.node_defs: raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') lines.append(f'class {node_id} highlighted') return '\n'.join(lines) ``` ### request_image ```python request_image( graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> bytes ``` Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink). Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `graph` | `Graph[Any, Any, Any]` | The graph to generate the image for. | *required* | | `**kwargs` | `Unpack[MermaidConfig]` | Additional parameters to configure mermaid chart generation. | `{}` | Returns: | Type | Description | | --- | --- | | `bytes` | The image data. | Source code in `pydantic_graph/pydantic_graph/mermaid.py` ```python def request_image( graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> bytes: """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink). Args: graph: The graph to generate the image for. **kwargs: Additional parameters to configure mermaid chart generation. Returns: The image data. 
""" code = generate_code( graph, start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), title=kwargs.get('title'), edge_labels=kwargs.get('edge_labels', True), notes=kwargs.get('notes', True), direction=kwargs.get('direction'), ) code_base64 = base64.b64encode(code.encode()).decode() params: dict[str, str | float] = {} if kwargs.get('image_type') == 'pdf': url = f'https://mermaid.ink/pdf/{code_base64}' if kwargs.get('pdf_fit'): params['fit'] = '' if kwargs.get('pdf_landscape'): params['landscape'] = '' if pdf_paper := kwargs.get('pdf_paper'): params['paper'] = pdf_paper elif kwargs.get('image_type') == 'svg': url = f'https://mermaid.ink/svg/{code_base64}' else: url = f'https://mermaid.ink/img/{code_base64}' if image_type := kwargs.get('image_type'): params['type'] = image_type if background_color := kwargs.get('background_color'): params['bgColor'] = background_color if theme := kwargs.get('theme'): params['theme'] = theme if width := kwargs.get('width'): params['width'] = width if height := kwargs.get('height'): params['height'] = height if scale := kwargs.get('scale'): params['scale'] = scale httpx_client = kwargs.get('httpx_client') or httpx.Client() response = httpx_client.get(url, params=params) if not response.is_success: raise httpx.HTTPStatusError( f'{response.status_code} error generating image:\n{response.text}', request=response.request, response=response, ) return response.content ``` ### save_image ```python save_image( path: Path | str, graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> None ``` Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `path` | `Path | str` | The path to save the image to. | *required* | | `graph` | `Graph[Any, Any, Any]` | The graph to generate the image for. | *required* | | `**kwargs` | `Unpack[MermaidConfig]` | Additional parameters to configure mermaid chart generation. | `{}` | Source code in `pydantic_graph/pydantic_graph/mermaid.py` ```python def save_image( path: Path | str, graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> None: """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file. Args: path: The path to save the image to. graph: The graph to generate the image for. **kwargs: Additional parameters to configure mermaid chart generation. """ if isinstance(path, str): path = Path(path) if 'image_type' not in kwargs: ext = path.suffix.lower()[1:] # no need to check for .jpeg/.jpg, as it is the default if ext in ('png', 'webp', 'svg', 'pdf'): kwargs['image_type'] = ext image_data = request_image(graph, **kwargs) path.write_bytes(image_data) ``` ### MermaidConfig Bases: `TypedDict` Parameters to configure mermaid chart generation. 
Source code in `pydantic_graph/pydantic_graph/mermaid.py` ```python class MermaidConfig(TypedDict, total=False): """Parameters to configure mermaid chart generation.""" start_node: Sequence[NodeIdent] | NodeIdent """Identifiers of nodes that start the graph.""" highlighted_nodes: Sequence[NodeIdent] | NodeIdent """Identifiers of nodes to highlight.""" highlight_css: str """CSS to use for highlighting nodes.""" title: str | None """The title of the diagram.""" edge_labels: bool """Whether to include edge labels in the diagram.""" notes: bool """Whether to include notes on nodes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" pdf_fit: bool """When using image_type='pdf', whether to fit the diagram to the PDF page.""" pdf_landscape: bool """When using image_type='pdf', whether to use landscape orientation for the PDF. This has no effect if using `pdf_fit`. """ pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] """When using image_type='pdf', the paper size of the PDF.""" background_color: str """The background color of the diagram. If None, the default transparent background is used. The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), but you can also use named colors by prefixing the value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. """ theme: Literal['default', 'neutral', 'dark', 'forest'] """The theme of the diagram. Defaults to 'default'.""" width: int """The width of the diagram.""" height: int """The height of the diagram.""" scale: Annotated[float, Ge(1), Le(3)] """The scale of the diagram. The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. """ httpx_client: httpx.Client """An HTTPX client to use for requests, mostly for testing purposes.""" direction: StateDiagramDirection """The direction of the state diagram.""" ``` #### start_node ```python start_node: Sequence[NodeIdent] | NodeIdent ``` Identifiers of nodes that start the graph. #### highlighted_nodes ```python highlighted_nodes: Sequence[NodeIdent] | NodeIdent ``` Identifiers of nodes to highlight. #### highlight_css ```python highlight_css: str ``` CSS to use for highlighting nodes. #### title ```python title: str | None ``` The title of the diagram. #### edge_labels ```python edge_labels: bool ``` Whether to include edge labels in the diagram. #### notes ```python notes: bool ``` Whether to include notes on nodes in the diagram, defaults to true. #### image_type ```python image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] ``` The image type to generate. If unspecified, the default behavior is `'jpeg'`. #### pdf_fit ```python pdf_fit: bool ``` When using image_type='pdf', whether to fit the diagram to the PDF page. #### pdf_landscape ```python pdf_landscape: bool ``` When using image_type='pdf', whether to use landscape orientation for the PDF. This has no effect if using `pdf_fit`. #### pdf_paper ```python pdf_paper: Literal[ "letter", "legal", "tabloid", "ledger", "a0", "a1", "a2", "a3", "a4", "a5", "a6", ] ``` When using image_type='pdf', the paper size of the PDF. #### background_color ```python background_color: str ``` The background color of the diagram. If None, the default transparent background is used. 
The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), but you can also use named colors by prefixing the value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. #### theme ```python theme: Literal['default', 'neutral', 'dark', 'forest'] ``` The theme of the diagram. Defaults to 'default'. #### width ```python width: int ``` The width of the diagram. #### height ```python height: int ``` The height of the diagram. #### scale ```python scale: Annotated[float, Ge(1), Le(3)] ``` The scale of the diagram. The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. #### httpx_client ```python httpx_client: Client ``` An HTTPX client to use for requests, mostly for testing purposes. #### direction ```python direction: StateDiagramDirection ``` The direction of the state diagram. ### NodeIdent ```python NodeIdent: TypeAlias = ( "type[BaseNode[Any, Any, Any]] | BaseNode[Any, Any, Any] | str" ) ``` A type alias for a node identifier. This can be: - A node instance (instance of a subclass of BaseNode). - A node class (subclass of BaseNode). - A string representing the node ID. # `pydantic_graph.nodes` ### StateT ```python StateT = TypeVar('StateT', default=None) ``` Type variable for the state in a graph. ### GraphRunContext Bases: `Generic[StateT, DepsT]` Context for a graph. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @dataclass class GraphRunContext(Generic[StateT, DepsT]): """Context for a graph.""" # TODO: Can we get rid of this struct and just pass both these things around..? state: StateT """The state of the graph.""" deps: DepsT """Dependencies for the graph.""" ``` #### state ```python state: StateT ``` The state of the graph. #### deps ```python deps: DepsT ``` Dependencies for the graph. ### BaseNode Bases: `ABC`, `Generic[StateT, DepsT, NodeRunEndT]` Base class for a node. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]): """Base class for a node.""" docstring_notes: ClassVar[bool] = False """Set to `True` to generate mermaid diagram notes from the class's docstring. While this can add valuable information to the diagram, it can make diagrams harder to view, hence it is disabled by default. You can also customise notes overriding the [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method. """ @abstractmethod async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: """Run the node. This is an abstract method that must be implemented by subclasses. !!! note "Return types used at runtime" The return type of this method are read by `pydantic_graph` at runtime and used to define which nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md) and enforced when running the graph. Args: ctx: The graph context. Returns: The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph. """ ... 
def get_snapshot_id(self) -> str: if snapshot_id := getattr(self, '__snapshot_id', None): return snapshot_id else: self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id(self.get_node_id()) return snapshot_id def set_snapshot_id(self, snapshot_id: str) -> None: self.__dict__['__snapshot_id'] = snapshot_id @classmethod @cache def get_node_id(cls) -> str: """Get the ID of the node.""" return cls.__name__ @classmethod def get_note(cls) -> str | None: """Get a note about the node to render on mermaid charts. By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] is `True`. You can override this method to customise the node notes. """ if not cls.docstring_notes: return None docstring = cls.__doc__ # dataclasses get an automatic docstring which is just their signature, we don't want that if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): docstring = None # pragma: no cover if docstring: # pragma: no branch # remove indentation from docstring import inspect docstring = inspect.cleandoc(docstring) return docstring @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]: """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] except KeyError as e: raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e next_node_edges: dict[str, Edge] = {} end_edge: Edge | None = None returns_base_node: bool = False for return_type in _utils.get_union_args(return_hint): return_type, annotations = _utils.unpack_annotated(return_type) edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: end_edge = edge elif return_type_origin is BaseNode: # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): next_node_edges[return_type.get_node_id()] = edge else: raise exceptions.GraphSetupError(f'Invalid return type: {return_type}') return NodeDef( cls, cls.get_node_id(), cls.get_note(), next_node_edges, end_edge, returns_base_node, ) def deep_copy(self) -> Self: """Returns a deep copy of the node.""" return copy.deepcopy(self) ``` #### docstring_notes ```python docstring_notes: bool = False ``` Set to `True` to generate mermaid diagram notes from the class's docstring. While this can add valuable information to the diagram, it can make diagrams harder to view, hence it is disabled by default. You can also customise notes overriding the get_note method. #### run ```python run( ctx: GraphRunContext[StateT, DepsT], ) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT] ``` Run the node. This is an abstract method that must be implemented by subclasses. Return types used at runtime The return type of this method are read by `pydantic_graph` at runtime and used to define which nodes can be called next in the graph. This is displayed in [mermaid diagrams](../mermaid/) and enforced when running the graph. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `ctx` | `GraphRunContext[StateT, DepsT]` | The graph context. | *required* | Returns: | Type | Description | | --- | --- | | `BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]` | The next node to run or End to signal the end of the graph. 
| Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @abstractmethod async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: """Run the node. This is an abstract method that must be implemented by subclasses. !!! note "Return types used at runtime" The return type of this method are read by `pydantic_graph` at runtime and used to define which nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md) and enforced when running the graph. Args: ctx: The graph context. Returns: The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph. """ ... ``` #### get_node_id ```python get_node_id() -> str ``` Get the ID of the node. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @classmethod @cache def get_node_id(cls) -> str: """Get the ID of the node.""" return cls.__name__ ``` #### get_note ```python get_note() -> str | None ``` Get a note about the node to render on mermaid charts. By default, this returns a note only if docstring_notes is `True`. You can override this method to customise the node notes. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @classmethod def get_note(cls) -> str | None: """Get a note about the node to render on mermaid charts. By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] is `True`. You can override this method to customise the node notes. """ if not cls.docstring_notes: return None docstring = cls.__doc__ # dataclasses get an automatic docstring which is just their signature, we don't want that if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): docstring = None # pragma: no cover if docstring: # pragma: no branch # remove indentation from docstring import inspect docstring = inspect.cleandoc(docstring) return docstring ``` #### get_node_def ```python get_node_def( local_ns: dict[str, Any] | None, ) -> NodeDef[StateT, DepsT, NodeRunEndT] ``` Get the node definition. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]: """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] except KeyError as e: raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e next_node_edges: dict[str, Edge] = {} end_edge: Edge | None = None returns_base_node: bool = False for return_type in _utils.get_union_args(return_hint): return_type, annotations = _utils.unpack_annotated(return_type) edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: end_edge = edge elif return_type_origin is BaseNode: # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): next_node_edges[return_type.get_node_id()] = edge else: raise exceptions.GraphSetupError(f'Invalid return type: {return_type}') return NodeDef( cls, cls.get_node_id(), cls.get_note(), next_node_edges, end_edge, returns_base_node, ) ``` #### deep_copy ```python deep_copy() -> Self ``` Returns a deep copy of the node. 
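To make the runtime use of `run` return annotations concrete, here's a minimal, illustrative sketch — the `CountdownState`, `Countdown`, and `countdown_graph` names are invented for this example and are not part of pydantic_graph:

countdown.py

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Annotated

from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext


@dataclass
class CountdownState:
    counter: int


@dataclass
class Countdown(BaseNode[CountdownState, None, int]):
    async def run(
        self, ctx: GraphRunContext[CountdownState]
    ) -> Annotated[Countdown, Edge(label='count down')] | End[int]:
        # `get_node_def` reads this return annotation: the node may loop back to
        # itself (the `Edge` label is rendered on mermaid diagrams) or end the
        # run with an `int` result.
        if ctx.state.counter <= 0:
            return End(ctx.state.counter)
        ctx.state.counter -= 1
        return Countdown()


countdown_graph = Graph(nodes=[Countdown])
```

As the `get_node_def` source above shows, omitting the return annotation on `run` raises a `GraphSetupError` when the graph is built, since the annotation is the only place the node's outgoing edges are declared.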
Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python def deep_copy(self) -> Self: """Returns a deep copy of the node.""" return copy.deepcopy(self) ``` ### End Bases: `Generic[RunEndT]` Type to return from a node to signal the end of the graph. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @dataclass class End(Generic[RunEndT]): """Type to return from a node to signal the end of the graph.""" data: RunEndT """Data to return from the graph.""" def deep_copy_data(self) -> End[RunEndT]: """Returns a deep copy of the end of the run.""" if self.data is None: return self else: end = End(copy.deepcopy(self.data)) end.set_snapshot_id(self.get_snapshot_id()) return end def get_snapshot_id(self) -> str: if snapshot_id := getattr(self, '__snapshot_id', None): return snapshot_id else: self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id('end') return snapshot_id def set_snapshot_id(self, set_id: str) -> None: self.__dict__['__snapshot_id'] = set_id ``` #### data ```python data: RunEndT ``` Data to return from the graph. #### deep_copy_data ```python deep_copy_data() -> End[RunEndT] ``` Returns a deep copy of the end of the run. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python def deep_copy_data(self) -> End[RunEndT]: """Returns a deep copy of the end of the run.""" if self.data is None: return self else: end = End(copy.deepcopy(self.data)) end.set_snapshot_id(self.get_snapshot_id()) return end ``` ### Edge Annotation to apply a label to an edge in a graph. Source code in `pydantic_graph/pydantic_graph/nodes.py` ```python @dataclass(frozen=True) class Edge: """Annotation to apply a label to an edge in a graph.""" label: str | None """Label for the edge.""" ``` #### label ```python label: str | None ``` Label for the edge. ### DepsT ```python DepsT = TypeVar('DepsT', default=None, contravariant=True) ``` Type variable for the dependencies of a graph and node. ### RunEndT ```python RunEndT = TypeVar('RunEndT', covariant=True, default=None) ``` Covariant type variable for the return type of a graph run. ### NodeRunEndT ```python NodeRunEndT = TypeVar( "NodeRunEndT", covariant=True, default=Never ) ``` Covariant type variable for the return type of a node run. # `pydantic_graph.persistence` ### SnapshotStatus ```python SnapshotStatus = Literal[ "created", "pending", "running", "success", "error" ] ``` The status of a snapshot. - `'created'`: The snapshot has been created but not yet run. - `'pending'`: The snapshot has been retrieved with load_next but not yet run. - `'running'`: The snapshot is currently running. - `'success'`: The snapshot has been run successfully. - `'error'`: The snapshot has been run but an error occurred. ### NodeSnapshot Bases: `Generic[StateT, RunEndT]` History step describing the execution of a node in a graph. 
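Before the field-by-field reference, here's a hedged sketch of how these snapshot types show up in practice, reusing the illustrative `countdown.py` graph from the `BaseNode` section together with `FullStatePersistence` (documented further below); the printed history assumes the run behaves as described above.

inspect_history.py

```python
import asyncio

from pydantic_graph import FullStatePersistence

# Illustrative module from the `BaseNode` sketch above, not part of the library.
from countdown import Countdown, CountdownState, countdown_graph


async def main():
    persistence = FullStatePersistence()
    await countdown_graph.run(
        Countdown(), state=CountdownState(counter=2), persistence=persistence
    )
    for snapshot in persistence.history:
        # NodeSnapshots record the node, its status, and timing; the final
        # EndSnapshot records the `End` result instead of a node.
        print(type(snapshot).__name__, snapshot.kind, getattr(snapshot, 'status', None))
    '''
    NodeSnapshot node success
    NodeSnapshot node success
    NodeSnapshot node success
    EndSnapshot end None
    '''


asyncio.run(main())
```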
Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @dataclass class NodeSnapshot(Generic[StateT, RunEndT]): """History step describing the execution of a node in a graph.""" state: StateT """The state of the graph before the node is run.""" node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()] """The node to run next.""" start_ts: datetime | None = None """The timestamp when the node started running, `None` until the run starts.""" duration: float | None = None """The duration of the node run in seconds, if the node has been run.""" status: SnapshotStatus = 'created' """The status of the snapshot.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" id: str = UNSET_SNAPSHOT_ID """Unique ID of the snapshot.""" def __post_init__(self) -> None: if self.id == UNSET_SNAPSHOT_ID: self.id = self.node.get_snapshot_id() ``` #### state ```python state: StateT ``` The state of the graph before the node is run. #### node ```python node: Annotated[ BaseNode[StateT, Any, RunEndT], CustomNodeSchema() ] ``` The node to run next. #### start_ts ```python start_ts: datetime | None = None ``` The timestamp when the node started running, `None` until the run starts. #### duration ```python duration: float | None = None ``` The duration of the node run in seconds, if the node has been run. #### status ```python status: SnapshotStatus = 'created' ``` The status of the snapshot. #### kind ```python kind: Literal['node'] = 'node' ``` The kind of history step, can be used as a discriminator when deserializing history. #### id ```python id: str = UNSET_SNAPSHOT_ID ``` Unique ID of the snapshot. ### EndSnapshot Bases: `Generic[StateT, RunEndT]` History step describing the end of a graph run. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @dataclass class EndSnapshot(Generic[StateT, RunEndT]): """History step describing the end of a graph run.""" state: StateT """The state of the graph at the end of the run.""" result: End[RunEndT] """The result of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" id: str = UNSET_SNAPSHOT_ID """Unique ID of the snapshot.""" def __post_init__(self) -> None: if self.id == UNSET_SNAPSHOT_ID: self.id = self.node.get_snapshot_id() @property def node(self) -> End[RunEndT]: """Shim to get the [`result`][pydantic_graph.persistence.EndSnapshot.result]. Useful to allow `[snapshot.node for snapshot in persistence.history]`. """ return self.result ``` #### state ```python state: StateT ``` The state of the graph at the end of the run. #### result ```python result: End[RunEndT] ``` The result of the graph run. #### ts ```python ts: datetime = field(default_factory=now_utc) ``` The timestamp when the graph run ended. #### kind ```python kind: Literal['end'] = 'end' ``` The kind of history step, can be used as a discriminator when deserializing history. #### id ```python id: str = UNSET_SNAPSHOT_ID ``` Unique ID of the snapshot. #### node ```python node: End[RunEndT] ``` Shim to get the result. Useful to allow `[snapshot.node for snapshot in persistence.history]`. ### Snapshot ```python Snapshot = Union[ NodeSnapshot[StateT, RunEndT], EndSnapshot[StateT, RunEndT], ] ``` A step in the history of a graph run. 
Graph.run returns a list of these steps describing the execution of the graph, together with the run return value. ### BaseStatePersistence Bases: `ABC`, `Generic[StateT, RunEndT]` Abstract base class for storing the state of a graph run. Each instance of a `BaseStatePersistence` subclass should be used for a single graph run. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python class BaseStatePersistence(ABC, Generic[StateT, RunEndT]): """Abstract base class for storing the state of a graph run. Each instance of a `BaseStatePersistence` subclass should be used for a single graph run. """ @abstractmethod async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: """Snapshot the state of a graph, when the next step is to run a node. This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence. Args: state: The state of the graph. next_node: The next node to run. """ raise NotImplementedError @abstractmethod async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence. This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node] but should do so in an atomic way. Args: snapshot_id: The ID of the snapshot to check. state: The state of the graph. next_node: The next node to run. """ raise NotImplementedError @abstractmethod async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: """Snapshot the state of a graph when the graph has ended. This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence. Args: state: The state of the graph. end: data from the end of the run. """ raise NotImplementedError @abstractmethod def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: """Record the run of the node, or error if the node is already running. Args: snapshot_id: The ID of the snapshot to record. Raises: GraphNodeRunningError: if the node status it not `'created'` or `'pending'`. LookupError: if the snapshot ID is not found in persistence. Returns: An async context manager that records the run of the node. In particular this should set: - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts. - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes. """ raise NotImplementedError @abstractmethod async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: """Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence] to get the next node to run. Returns: The snapshot, or `None` if no snapshot with status `'created`' exists. """ raise NotImplementedError @abstractmethod async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: """Load the entire history of snapshots. `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence. Returns: The list of snapshots. 
""" raise NotImplementedError def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None: """Set the types of the state and run end from a graph. You generally won't need to customise this method, instead implement [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types]. """ if self.should_set_types(): with _utils.set_nodes_type_context(graph.get_nodes()): self.set_types(*graph.inferred_types) def should_set_types(self) -> bool: """Whether types need to be set. Implementations should override this method to return `True` when types have not been set if they are needed. """ return False def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: """Set the types of the state and run end. This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots, e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter]. Args: state_type: The state type. run_end_type: The run end type. """ pass ``` #### snapshot_node ```python snapshot_node( state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None ``` Snapshot the state of a graph, when the next step is to run a node. This method should add a NodeSnapshot to persistence. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `state` | `StateT` | The state of the graph. | *required* | | `next_node` | `BaseNode[StateT, Any, RunEndT]` | The next node to run. | *required* | Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @abstractmethod async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: """Snapshot the state of a graph, when the next step is to run a node. This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence. Args: state: The state of the graph. next_node: The next node to run. """ raise NotImplementedError ``` #### snapshot_node_if_new ```python snapshot_node_if_new( snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT], ) -> None ``` Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence. This method will generally call snapshot_node but should do so in an atomic way. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `snapshot_id` | `str` | The ID of the snapshot to check. | *required* | | `state` | `StateT` | The state of the graph. | *required* | | `next_node` | `BaseNode[StateT, Any, RunEndT]` | The next node to run. | *required* | Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @abstractmethod async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence. This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node] but should do so in an atomic way. Args: snapshot_id: The ID of the snapshot to check. state: The state of the graph. next_node: The next node to run. """ raise NotImplementedError ``` #### snapshot_end ```python snapshot_end(state: StateT, end: End[RunEndT]) -> None ``` Snapshot the state of a graph when the graph has ended. This method should add an EndSnapshot to persistence. 
Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `state` | `StateT` | The state of the graph. | *required* | | `end` | `End[RunEndT]` | data from the end of the run. | *required* | Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @abstractmethod async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: """Snapshot the state of a graph when the graph has ended. This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence. Args: state: The state of the graph. end: data from the end of the run. """ raise NotImplementedError ``` #### record_run ```python record_run( snapshot_id: str, ) -> AbstractAsyncContextManager[None] ``` Record the run of the node, or error if the node is already running. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `snapshot_id` | `str` | The ID of the snapshot to record. | *required* | Raises: | Type | Description | | --- | --- | | `GraphNodeRunningError` | if the node status it not 'created' or 'pending'. | | `LookupError` | if the snapshot ID is not found in persistence. | Returns: | Type | Description | | --- | --- | | `AbstractAsyncContextManager[None]` | An async context manager that records the run of the node. | In particular this should set: - NodeSnapshot.status to `'running'` and NodeSnapshot.start_ts when the run starts. - NodeSnapshot.status to `'success'` or `'error'` and NodeSnapshot.duration when the run finishes. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @abstractmethod def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: """Record the run of the node, or error if the node is already running. Args: snapshot_id: The ID of the snapshot to record. Raises: GraphNodeRunningError: if the node status it not `'created'` or `'pending'`. LookupError: if the snapshot ID is not found in persistence. Returns: An async context manager that records the run of the node. In particular this should set: - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts. - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes. """ raise NotImplementedError ``` #### load_next ```python load_next() -> NodeSnapshot[StateT, RunEndT] | None ``` Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. This is used by Graph.iter_from_persistence to get the next node to run. Returns: The snapshot, or `None` if no snapshot with status `'created`' exists. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @abstractmethod async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: """Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence] to get the next node to run. Returns: The snapshot, or `None` if no snapshot with status `'created`' exists. """ raise NotImplementedError ``` #### load_all ```python load_all() -> list[Snapshot[StateT, RunEndT]] ``` Load the entire history of snapshots. `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to get all snapshots from persistence. 
Returns: The list of snapshots. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python @abstractmethod async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: """Load the entire history of snapshots. `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence. Returns: The list of snapshots. """ raise NotImplementedError ``` #### set_graph_types ```python set_graph_types(graph: Graph[StateT, Any, RunEndT]) -> None ``` Set the types of the state and run end from a graph. You generally won't need to customise this method, instead implement set_types and should_set_types. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None: """Set the types of the state and run end from a graph. You generally won't need to customise this method, instead implement [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types]. """ if self.should_set_types(): with _utils.set_nodes_type_context(graph.get_nodes()): self.set_types(*graph.inferred_types) ``` #### should_set_types ```python should_set_types() -> bool ``` Whether types need to be set. Implementations should override this method to return `True` when types have not been set if they are needed. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python def should_set_types(self) -> bool: """Whether types need to be set. Implementations should override this method to return `True` when types have not been set if they are needed. """ return False ``` #### set_types ```python set_types( state_type: type[StateT], run_end_type: type[RunEndT] ) -> None ``` Set the types of the state and run end. This can be used to create type adapters for serializing and deserializing snapshots, e.g. with build_snapshot_list_type_adapter. Parameters: | Name | Type | Description | Default | | --- | --- | --- | --- | | `state_type` | `type[StateT]` | The state type. | *required* | | `run_end_type` | `type[RunEndT]` | The run end type. | *required* | Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: """Set the types of the state and run end. This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots, e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter]. Args: state_type: The state type. run_end_type: The run end type. """ pass ``` ### build_snapshot_list_type_adapter ```python build_snapshot_list_type_adapter( state_t: type[StateT], run_end_t: type[RunEndT] ) -> TypeAdapter[list[Snapshot[StateT, RunEndT]]] ``` Build a type adapter for a list of snapshots. This method should be called from within set_types where context variables will be set such that Pydantic can create a schema for NodeSnapshot.node. Source code in `pydantic_graph/pydantic_graph/persistence/__init__.py` ```python def build_snapshot_list_type_adapter( state_t: type[StateT], run_end_t: type[RunEndT] ) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]: """Build a type adapter for a list of snapshots. 
This method should be called from within [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] where context variables will be set such that Pydantic can create a schema for [`NodeSnapshot.node`][pydantic_graph.persistence.NodeSnapshot.node]. """ return pydantic.TypeAdapter(list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]]) ``` In memory state persistence. This module provides simple in memory state persistence for graphs. ### SimpleStatePersistence Bases: `BaseStatePersistence[StateT, RunEndT]` Simple in memory state persistence that just hold the latest snapshot. If no state persistence implementation is provided when running a graph, this is used by default. Source code in `pydantic_graph/pydantic_graph/persistence/in_mem.py` ```python @dataclass class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]): """Simple in memory state persistence that just hold the latest snapshot. If no state persistence implementation is provided when running a graph, this is used by default. """ last_snapshot: Snapshot[StateT, RunEndT] | None = None """The last snapshot.""" async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: self.last_snapshot = NodeSnapshot(state=state, node=next_node) async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: if self.last_snapshot and self.last_snapshot.id == snapshot_id: return # pragma: no cover else: await self.snapshot_node(state, next_node) async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: self.last_snapshot = EndSnapshot(state=state, result=end) @asynccontextmanager async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: if self.last_snapshot is None or snapshot_id != self.last_snapshot.id: raise LookupError(f'No snapshot found with id={snapshot_id!r}') assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' exceptions.GraphNodeStatusError.check(self.last_snapshot.status) self.last_snapshot.status = 'running' self.last_snapshot.start_ts = _utils.now_utc() start = perf_counter() try: yield except Exception: self.last_snapshot.duration = perf_counter() - start self.last_snapshot.status = 'error' raise else: self.last_snapshot.duration = perf_counter() - start self.last_snapshot.status = 'success' async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created': self.last_snapshot.status = 'pending' return self.last_snapshot async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: raise NotImplementedError('load is not supported for SimpleStatePersistence') ``` #### last_snapshot ```python last_snapshot: Snapshot[StateT, RunEndT] | None = None ``` The last snapshot. ### FullStatePersistence Bases: `BaseStatePersistence[StateT, RunEndT]` In memory state persistence that hold a list of snapshots. Source code in `pydantic_graph/pydantic_graph/persistence/in_mem.py` ```python @dataclass class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]): """In memory state persistence that hold a list of snapshots.""" deep_copy: bool = True """Whether to deep copy the state and nodes when storing them. Defaults to `True` so even if nodes or state are modified after the snapshot is taken, the persistence history will record the value at the time of the snapshot. 
""" history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list) """List of snapshots taken during the graph run.""" _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( default=None, init=False, repr=False ) async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: snapshot = NodeSnapshot( state=self._prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) self.history.append(snapshot) async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: if not any(s.id == snapshot_id for s in self.history): await self.snapshot_node(state, next_node) async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: snapshot = EndSnapshot( state=self._prep_state(state), result=end.deep_copy_data() if self.deep_copy else end, ) self.history.append(snapshot) @asynccontextmanager async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: try: snapshot = next(s for s in self.history if s.id == snapshot_id) except StopIteration as e: raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' exceptions.GraphNodeStatusError.check(snapshot.status) snapshot.status = 'running' snapshot.start_ts = _utils.now_utc() start = perf_counter() try: yield except Exception: snapshot.duration = perf_counter() - start snapshot.status = 'error' raise else: snapshot.duration = perf_counter() - start snapshot.status = 'success' async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None): snapshot.status = 'pending' return snapshot async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return self.history def should_set_types(self) -> bool: return self._snapshots_type_adapter is None def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) def dump_json(self, *, indent: int | None = None) -> bytes: """Dump the history to JSON bytes.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`' return self._snapshots_type_adapter.dump_json(self.history, indent=indent) def load_json(self, json_data: str | bytes | bytearray) -> None: """Load the history from JSON.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`' self.history = self._snapshots_type_adapter.validate_json(json_data) def _prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" if not self.deep_copy or state is None: return state else: return copy.deepcopy(state) ``` #### deep_copy ```python deep_copy: bool = True ``` Whether to deep copy the state and nodes when storing them. Defaults to `True` so even if nodes or state are modified after the snapshot is taken, the persistence history will record the value at the time of the snapshot. #### history ```python history: list[Snapshot[StateT, RunEndT]] = field( default_factory=list ) ``` List of snapshots taken during the graph run. #### dump_json ```python dump_json(*, indent: int | None = None) -> bytes ``` Dump the history to JSON bytes. 
Source code in `pydantic_graph/pydantic_graph/persistence/in_mem.py` ```python def dump_json(self, *, indent: int | None = None) -> bytes: """Dump the history to JSON bytes.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`' return self._snapshots_type_adapter.dump_json(self.history, indent=indent) ``` #### load_json ```python load_json(json_data: str | bytes | bytearray) -> None ``` Load the history from JSON. Source code in `pydantic_graph/pydantic_graph/persistence/in_mem.py` ```python def load_json(self, json_data: str | bytes | bytearray) -> None: """Load the history from JSON.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`' self.history = self._snapshots_type_adapter.validate_json(json_data) ``` ### FileStatePersistence Bases: `BaseStatePersistence[StateT, RunEndT]` File based state persistence that hold graph run state in a JSON file. Source code in `pydantic_graph/pydantic_graph/persistence/file.py` ````python @dataclass class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]): """File based state persistence that hold graph run state in a JSON file.""" json_file: Path """Path to the JSON file where the snapshots are stored. You should use a different file for each graph run, but a single file should be reused for multiple steps of the same run. For example if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus: ```py from pathlib import Path from pydantic_graph import FullStatePersistence run_id = 'run_123abc' persistence = FullStatePersistence(Path('runs') / f'{run_id}.json') ``` """ _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( default=None, init=False, repr=False ) async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: await self._append_save(NodeSnapshot(state=state, node=next_node)) async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: async with self._lock(): snapshots = await self.load_all() if not any(s.id == snapshot_id for s in snapshots): # pragma: no branch await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False) async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: await self._append_save(EndSnapshot(state=state, result=end)) @asynccontextmanager async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: async with self._lock(): snapshots = await self.load_all() try: snapshot = next(s for s in snapshots if s.id == snapshot_id) except StopIteration as e: raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' exceptions.GraphNodeStatusError.check(snapshot.status) snapshot.status = 'running' snapshot.start_ts = _utils.now_utc() await self._save(snapshots) start = perf_counter() try: yield except Exception: duration = perf_counter() - start async with self._lock(): await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, duration, 'error') raise else: snapshot.duration = perf_counter() - start async with self._lock(): await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success') async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: async with self._lock(): snapshots = await self.load_all() if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and 
s.status == 'created'), None): snapshot.status = 'pending' await self._save(snapshots) return snapshot def should_set_types(self) -> bool: """Whether types need to be set.""" return self._snapshots_type_adapter is None def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return await _graph_utils.run_in_executor(self._load_sync) def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]: assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set' try: content = self.json_file.read_bytes() except FileNotFoundError: return [] else: return self._snapshots_type_adapter.validate_json(content) def _after_run_sync(self, snapshot_id: str, duration: float, status: SnapshotStatus) -> None: snapshots = self._load_sync() snapshot = next(s for s in snapshots if s.id == snapshot_id) assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' snapshot.duration = duration snapshot.status = status self._save_sync(snapshots) async def _save(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None: await _graph_utils.run_in_executor(self._save_sync, snapshots) def _save_sync(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None: assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set' self.json_file.write_bytes(self._snapshots_type_adapter.dump_json(snapshots, indent=2)) async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool = True) -> None: assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set' async with AsyncExitStack() as stack: if lock: await stack.enter_async_context(self._lock()) snapshots = await self.load_all() snapshots.append(snapshot) await self._save(snapshots) @asynccontextmanager async def _lock(self, *, timeout: float = 1.0) -> AsyncIterator[None]: """Lock a file by checking and writing a `.pydantic-graph-persistence-lock` to it. Args: timeout: how long to wait for the lock Returns: an async context manager that holds the lock """ lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock' lock_id = secrets.token_urlsafe().encode() await asyncio.wait_for(_get_lock(lock_file, lock_id), timeout=timeout) try: yield finally: await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True) ```` #### json_file ```python json_file: Path ``` Path to the JSON file where the snapshots are stored. You should use a different file for each graph run, but a single file should be reused for multiple steps of the same run. For example if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus: ```py from pathlib import Path from pydantic_graph import FullStatePersistence run_id = 'run_123abc' persistence = FullStatePersistence(Path('runs') / f'{run_id}.json') ``` #### should_set_types ```python should_set_types() -> bool ``` Whether types need to be set. Source code in `pydantic_graph/pydantic_graph/persistence/file.py` ```python def should_set_types(self) -> bool: """Whether types need to be set.""" return self._snapshots_type_adapter is None ``` # Evals # Evals "Evals" refers to evaluating a model's performance for a specific application. Warning Unlike unit tests, evals are an emerging art/science; anyone who claims to know for sure exactly how your evals should be defined can safely be ignored. 
Pydantic Evals is a powerful evaluation framework designed to help you systematically test and evaluate the performance and accuracy of the systems you build, especially when working with LLMs. We've designed Pydantic Evals to be useful while not being too opinionated since we (along with everyone else) are still figuring out best practices. We'd love your [feedback](../help/) on the package and how we can improve it. In Beta Pydantic Evals support was [introduced](https://github.com/pydantic/pydantic-ai/pull/935) in v0.0.47 and is currently in beta. The API is subject to change and the documentation is incomplete. ## Installation To install the Pydantic Evals package, run: ```bash pip install pydantic-evals ``` ```bash uv add pydantic-evals ``` `pydantic-evals` does not depend on `pydantic-ai`, but has an optional dependency on `logfire` if you'd like to use OpenTelemetry traces in your evals, or send evaluation results to [logfire](https://pydantic.dev/logfire). ```bash pip install 'pydantic-evals[logfire]' ``` ```bash uv add 'pydantic-evals[logfire]' ``` ## Datasets and Cases In Pydantic Evals, everything begins with `Dataset`s and `Case`s: - Case: A single test scenario corresponding to "task" inputs. Can also optionally have a name, expected outputs, metadata, and evaluators. - Dataset: A collection of test cases designed for the evaluation of a specific task or function. simple_eval_dataset.py ```python from pydantic_evals import Case, Dataset case1 = Case( name='simple_case', inputs='What is the capital of France?', expected_output='Paris', metadata={'difficulty': 'easy'}, ) dataset = Dataset(cases=[case1]) ``` *(This example is complete, it can be run "as is")* ## Evaluators Evaluators are the components that analyze and score the results of your task when tested against a case. Pydantic Evals includes several built-in evaluators and allows you to create custom evaluators: simple_eval_evaluator.py ```python from dataclasses import dataclass from simple_eval_dataset import dataset from pydantic_evals.evaluators import Evaluator, EvaluatorContext from pydantic_evals.evaluators.common import IsInstance dataset.add_evaluator(IsInstance(type_name='str')) # (1)! @dataclass class MyEvaluator(Evaluator): async def evaluate(self, ctx: EvaluatorContext[str, str]) -> float: # (2)! if ctx.output == ctx.expected_output: return 1.0 elif ( isinstance(ctx.output, str) and ctx.expected_output.lower() in ctx.output.lower() ): return 0.8 else: return 0.0 dataset.add_evaluator(MyEvaluator()) ``` 1. You can add built-in evaluators to a dataset using the add_evaluator method. 1. This custom evaluator returns a simple score based on whether the output matches the expected output. *(This example is complete, it can be run "as is")* ## Evaluation Process The evaluation process involves running a task against all cases in a dataset: Putting the above two examples together and using the more declarative `evaluators` kwarg to Dataset: simple_eval_complete.py ```python from pydantic_evals import Case, Dataset from pydantic_evals.evaluators import Evaluator, EvaluatorContext, IsInstance case1 = Case( # (1)! 
name='simple_case', inputs='What is the capital of France?', expected_output='Paris', metadata={'difficulty': 'easy'}, ) class MyEvaluator(Evaluator[str, str]): def evaluate(self, ctx: EvaluatorContext[str, str]) -> float: if ctx.output == ctx.expected_output: return 1.0 elif ( isinstance(ctx.output, str) and ctx.expected_output.lower() in ctx.output.lower() ): return 0.8 else: return 0.0 dataset = Dataset( cases=[case1], evaluators=[IsInstance(type_name='str'), MyEvaluator()], # (3)! ) async def guess_city(question: str) -> str: # (4)! return 'Paris' report = dataset.evaluate_sync(guess_city) # (5)! report.print(include_input=True, include_output=True, include_durations=False) # (6)! """ Evaluation Summary: guess_city ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ Case ID ┃ Inputs ┃ Outputs ┃ Scores ┃ Assertions ┃ ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ simple_case │ What is the capital of France? │ Paris │ MyEvaluator: 1.00 │ ✔ │ ├─────────────┼────────────────────────────────┼─────────┼───────────────────┼────────────┤ │ Averages │ │ │ MyEvaluator: 1.00 │ 100.0% ✔ │ └─────────────┴────────────────────────────────┴─────────┴───────────────────┴────────────┘ """ ``` 1. Create a test case as above 1. Also create a custom evaluator function as above 1. Create a Dataset with test cases, also set the evaluators when creating the dataset 1. Our function to evaluate. 1. Run the evaluation with evaluate_sync, which runs the function against all test cases in the dataset, and returns an EvaluationReport object. 1. Print the report with print, which shows the results of the evaluation, including input and output. We have omitted duration here just to keep the printed output from changing from run to run. *(This example is complete, it can be run "as is")* ## Evaluation with `LLMJudge` In this example we evaluate a method for generating recipes based on customer orders. judge_recipes.py ```python from __future__ import annotations from typing import Any from pydantic import BaseModel from pydantic_ai import Agent, format_as_xml from pydantic_evals import Case, Dataset from pydantic_evals.evaluators import IsInstance, LLMJudge class CustomerOrder(BaseModel): # (1)! dish_name: str dietary_restriction: str | None = None class Recipe(BaseModel): ingredients: list[str] steps: list[str] recipe_agent = Agent( 'groq:llama-3.3-70b-versatile', output_type=Recipe, system_prompt=( 'Generate a recipe to cook the dish that meets the dietary restrictions.' ), ) async def transform_recipe(customer_order: CustomerOrder) -> Recipe: # (2)! r = await recipe_agent.run(format_as_xml(customer_order)) return r.output recipe_dataset = Dataset[CustomerOrder, Recipe, Any]( # (3)! cases=[ Case( name='vegetarian_recipe', inputs=CustomerOrder( dish_name='Spaghetti Bolognese', dietary_restriction='vegetarian' ), expected_output=None, # (4) metadata={'focus': 'vegetarian'}, evaluators=( LLMJudge( # (5)! rubric='Recipe should not contain meat or animal products', ), ), ), Case( name='gluten_free_recipe', inputs=CustomerOrder( dish_name='Chocolate Cake', dietary_restriction='gluten-free' ), expected_output=None, metadata={'focus': 'gluten-free'}, # Case-specific evaluator with a focused rubric evaluators=( LLMJudge( rubric='Recipe should not contain gluten or wheat products', ), ), ), ], evaluators=[ # (6)! 
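        # Dataset-level evaluators: these run for every case, alongside each case's own LLMJudge rubric above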
IsInstance(type_name='Recipe'), LLMJudge( rubric='Recipe should have clear steps and relevant ingredients', include_input=True, model='anthropic:claude-3-7-sonnet-latest', # (7)! ), ], ) report = recipe_dataset.evaluate_sync(transform_recipe) print(report) """ Evaluation Summary: transform_recipe ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┓ ┃ Case ID ┃ Assertions ┃ Duration ┃ ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━┩ │ vegetarian_recipe │ ✔✔✔ │ 10ms │ ├────────────────────┼────────────┼──────────┤ │ gluten_free_recipe │ ✔✔✔ │ 10ms │ ├────────────────────┼────────────┼──────────┤ │ Averages │ 100.0% ✔ │ 10ms │ └────────────────────┴────────────┴──────────┘ """ ``` 1. Define models for our task — Input for recipe generation task and output of the task. 1. Define our recipe generation function - this is the task we want to evaluate. 1. Create a dataset with different test cases and different rubrics. 1. No expected output, we'll let the LLM judge the quality. 1. Case-specific evaluator with a focused rubric using LLMJudge. 1. Dataset-level evaluators that apply to all cases, including a general quality rubric for all recipes 1. By default `LLMJudge` uses `openai:gpt-4o`, here we use a specific Anthropic model. *(This example is complete, it can be run "as is")* ## Saving and Loading Datasets Datasets can be saved to and loaded from YAML or JSON files. save_load_dataset_example.py ```python from pathlib import Path from judge_recipes import CustomerOrder, Recipe, recipe_dataset from pydantic_evals import Dataset recipe_transforms_file = Path('recipe_transform_tests.yaml') recipe_dataset.to_file(recipe_transforms_file) # (1)! print(recipe_transforms_file.read_text()) """ # yaml-language-server: $schema=recipe_transform_tests_schema.json cases: - name: vegetarian_recipe inputs: dish_name: Spaghetti Bolognese dietary_restriction: vegetarian metadata: focus: vegetarian evaluators: - LLMJudge: Recipe should not contain meat or animal products - name: gluten_free_recipe inputs: dish_name: Chocolate Cake dietary_restriction: gluten-free metadata: focus: gluten-free evaluators: - LLMJudge: Recipe should not contain gluten or wheat products evaluators: - IsInstance: Recipe - LLMJudge: rubric: Recipe should have clear steps and relevant ingredients model: anthropic:claude-3-7-sonnet-latest include_input: true """ # Load dataset from file loaded_dataset = Dataset[CustomerOrder, Recipe, dict].from_file(recipe_transforms_file) print(f'Loaded dataset with {len(loaded_dataset.cases)} cases') #> Loaded dataset with 2 cases ``` *(This example is complete, it can be run "as is")* ## Parallel Evaluation You can control concurrency during evaluation (this might be useful to prevent exceeding a rate limit): parallel_evaluation_example.py ```python import asyncio import time from pydantic_evals import Case, Dataset # Create a dataset with multiple test cases dataset = Dataset( cases=[ Case( name=f'case_{i}', inputs=i, expected_output=i * 2, ) for i in range(5) ] ) async def double_number(input_value: int) -> int: """Function that simulates work by sleeping for a tenth of a second before returning double the input.""" await asyncio.sleep(0.1) # Simulate work return input_value * 2 # Run evaluation with unlimited concurrency t0 = time.time() report_default = dataset.evaluate_sync(double_number) print(f'Evaluation took less than 0.5s: {time.time() - t0 < 0.5}') #> Evaluation took less than 0.5s: True report_default.print(include_input=True, include_output=True, include_durations=False) # (1)! 
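# Expected report output; durations are omitted so the table stays the same from run to run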
""" Evaluation Summary: double_number ┏━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┓ ┃ Case ID ┃ Inputs ┃ Outputs ┃ ┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━┩ │ case_0 │ 0 │ 0 │ ├──────────┼────────┼─────────┤ │ case_1 │ 1 │ 2 │ ├──────────┼────────┼─────────┤ │ case_2 │ 2 │ 4 │ ├──────────┼────────┼─────────┤ │ case_3 │ 3 │ 6 │ ├──────────┼────────┼─────────┤ │ case_4 │ 4 │ 8 │ ├──────────┼────────┼─────────┤ │ Averages │ │ │ └──────────┴────────┴─────────┘ """ # Run evaluation with limited concurrency t0 = time.time() report_limited = dataset.evaluate_sync(double_number, max_concurrency=1) print(f'Evaluation took more than 0.5s: {time.time() - t0 > 0.5}') #> Evaluation took more than 0.5s: True report_limited.print(include_input=True, include_output=True, include_durations=False) # (2)! """ Evaluation Summary: double_number ┏━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┓ ┃ Case ID ┃ Inputs ┃ Outputs ┃ ┡━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━┩ │ case_0 │ 0 │ 0 │ ├──────────┼────────┼─────────┤ │ case_1 │ 1 │ 2 │ ├──────────┼────────┼─────────┤ │ case_2 │ 2 │ 4 │ ├──────────┼────────┼─────────┤ │ case_3 │ 3 │ 6 │ ├──────────┼────────┼─────────┤ │ case_4 │ 4 │ 8 │ ├──────────┼────────┼─────────┤ │ Averages │ │ │ └──────────┴────────┴─────────┘ """ ``` 1. We have omitted duration here just to keep the printed output from changing from run to run. 1. We have omitted duration here just to keep the printed output from changing from run to run. *(This example is complete, it can be run "as is")* ## OpenTelemetry Integration Pydantic Evals integrates with OpenTelemetry for tracing. The EvaluatorContext includes a property called `span_tree` which returns a SpanTree. The `SpanTree` provides a way to query and analyze the spans generated during function execution. This provides a way to access the results of instrumentation during evaluation. Note If you just want to write unit tests that ensure that specific spans are produced during calls to your evaluation task, it's usually better to just use the `logfire.testing.capfire` fixture directly. There are two main ways this is useful. 
opentelemetry_example.py ```python import asyncio from typing import Any import logfire from pydantic_evals import Case, Dataset from pydantic_evals.evaluators import Evaluator from pydantic_evals.evaluators.context import EvaluatorContext from pydantic_evals.otel.span_tree import SpanQuery logfire.configure( # ensure that an OpenTelemetry tracer is configured send_to_logfire='if-token-present' ) class SpanTracingEvaluator(Evaluator[str, str]): """Evaluator that analyzes the span tree generated during function execution.""" def evaluate(self, ctx: EvaluatorContext[str, str]) -> dict[str, Any]: # Get the span tree from the context span_tree = ctx.span_tree if span_tree is None: return {'has_spans': False, 'performance_score': 0.0} # Find all spans with "processing" in the name processing_spans = span_tree.find(lambda node: 'processing' in node.name) # Calculate total processing time total_processing_time = sum( (span.duration.total_seconds() for span in processing_spans), 0.0 ) # Check for error spans error_query: SpanQuery = {'name_contains': 'error'} has_errors = span_tree.any(error_query) # Calculate a performance score (lower is better) performance_score = 1.0 if total_processing_time < 1.0 else 0.5 return { 'has_spans': True, 'has_errors': has_errors, 'performance_score': 0 if has_errors else performance_score, } async def process_text(text: str) -> str: """Function that processes text with OpenTelemetry instrumentation.""" with logfire.span('process_text'): # Simulate initial processing with logfire.span('text_processing'): await asyncio.sleep(0.1) processed = text.strip().lower() # Simulate additional processing with logfire.span('additional_processing'): if 'error' in processed: with logfire.span('error_handling'): logfire.error(f'Error detected in text: {text}') return f'Error processing: {text}' await asyncio.sleep(0.2) processed = processed.replace(' ', '_') return f'Processed: {processed}' # Create test cases dataset = Dataset( cases=[ Case( name='normal_text', inputs='Hello World', expected_output='Processed: hello_world', ), Case( name='text_with_error', inputs='Contains error marker', expected_output='Error processing: Contains error marker', ), ], evaluators=[SpanTracingEvaluator()], ) # Run evaluation - spans are automatically captured since logfire is configured report = dataset.evaluate_sync(process_text) # Print the report report.print(include_input=True, include_output=True, include_durations=False) # (1)! """ Evaluation Summary: process_text ┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ Case ID ┃ Inputs ┃ Outputs ┃ Scores ┃ Assertions ┃ ┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ normal_text │ Hello World │ Processed: hello_world │ performance_score: 1.00 │ ✔✗ │ ├─────────────────┼───────────────────────┼─────────────────────────────────────────┼──────────────────────────┼────────────┤ │ text_with_error │ Contains error marker │ Error processing: Contains error marker │ performance_score: 0 │ ✔✔ │ ├─────────────────┼───────────────────────┼─────────────────────────────────────────┼──────────────────────────┼────────────┤ │ Averages │ │ │ performance_score: 0.500 │ 75.0% ✔ │ └─────────────────┴───────────────────────┴─────────────────────────────────────────┴──────────────────────────┴────────────┘ """ ``` 1. We have omitted duration here just to keep the printed output from changing from run to run. 
*(This example is complete, it can be run "as is")* ## Generating Test Datasets Pydantic Evals allows you to generate test datasets using LLMs with generate_dataset. Datasets can be generated in either JSON or YAML format, in both cases a JSON schema file is generated alongside the dataset and referenced in the dataset, so you should get type checking and auto-completion in your editor. generate_dataset_example.py ```python from __future__ import annotations from pathlib import Path from pydantic import BaseModel, Field from pydantic_evals import Dataset from pydantic_evals.generation import generate_dataset class QuestionInputs(BaseModel, use_attribute_docstrings=True): # (1)! """Model for question inputs.""" question: str """A question to answer""" context: str | None = None """Optional context for the question""" class AnswerOutput(BaseModel, use_attribute_docstrings=True): # (2)! """Model for expected answer outputs.""" answer: str """The answer to the question""" confidence: float = Field(ge=0, le=1) """Confidence level (0-1)""" class MetadataType(BaseModel, use_attribute_docstrings=True): # (3)! """Metadata model for test cases.""" difficulty: str """Difficulty level (easy, medium, hard)""" category: str """Question category""" async def main(): dataset = await generate_dataset( # (4)! dataset_type=Dataset[QuestionInputs, AnswerOutput, MetadataType], n_examples=2, extra_instructions=""" Generate question-answer pairs about world capitals and landmarks. Make sure to include both easy and challenging questions. """, ) output_file = Path('questions_cases.yaml') dataset.to_file(output_file) # (5)! print(output_file.read_text()) """ # yaml-language-server: $schema=questions_cases_schema.json cases: - name: Easy Capital Question inputs: question: What is the capital of France? metadata: difficulty: easy category: Geography expected_output: answer: Paris confidence: 0.95 evaluators: - EqualsExpected - name: Challenging Landmark Question inputs: question: Which world-famous landmark is located on the banks of the Seine River? metadata: difficulty: hard category: Landmarks expected_output: answer: Eiffel Tower confidence: 0.9 evaluators: - EqualsExpected """ ``` 1. Define the schema for the inputs to the task. 1. Define the schema for the expected outputs of the task. 1. Define the schema for the metadata of the test cases. 1. Call generate_dataset to create a Dataset with 2 cases confirming to the schema. 1. Save the dataset to a YAML file, this will also write `questions_cases_schema.json` with the schema JSON schema for `questions_cases.yaml` to make editing easier. The magic `yaml-language-server` comment is supported by at least vscode, jetbrains/pycharm (more details [here](https://github.com/redhat-developer/yaml-language-server#using-inlined-schema)). *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main(answer))` to run `main`)* You can also write datasets as JSON files: generate_dataset_example_json.py ```python from pathlib import Path from generate_dataset_example import AnswerOutput, MetadataType, QuestionInputs from pydantic_evals import Dataset from pydantic_evals.generation import generate_dataset async def main(): dataset = await generate_dataset( # (1)! dataset_type=Dataset[QuestionInputs, AnswerOutput, MetadataType], n_examples=2, extra_instructions=""" Generate question-answer pairs about world capitals and landmarks. Make sure to include both easy and challenging questions. 
""", ) output_file = Path('questions_cases.json') dataset.to_file(output_file) # (2)! print(output_file.read_text()) """ { "$schema": "questions_cases_schema.json", "cases": [ { "name": "Easy Capital Question", "inputs": { "question": "What is the capital of France?" }, "metadata": { "difficulty": "easy", "category": "Geography" }, "expected_output": { "answer": "Paris", "confidence": 0.95 }, "evaluators": [ "EqualsExpected" ] }, { "name": "Challenging Landmark Question", "inputs": { "question": "Which world-famous landmark is located on the banks of the Seine River?" }, "metadata": { "difficulty": "hard", "category": "Landmarks" }, "expected_output": { "answer": "Eiffel Tower", "confidence": 0.9 }, "evaluators": [ "EqualsExpected" ] } ] } """ ``` 1. Generate the Dataset exactly as above. 1. Save the dataset to a JSON file, this will also write `questions_cases_schema.json` with th JSON schema for `questions_cases.json`. This time the `$schema` key is included in the JSON file to define the schema for IDEs to use while you edit the file, there's no formal spec for this, but it works in vscode and pycharm and is discussed at length in [json-schema-org/json-schema-spec#828](https://github.com/json-schema-org/json-schema-spec/issues/828). *(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main(answer))` to run `main`)* ## Integration with Logfire Pydantic Evals is implemented using OpenTelemetry to record traces of the evaluation process. These traces contain all the information included in the terminal output as attributes, but also include full tracing from the executions of the evaluation task function. You can send these traces to any OpenTelemetry-compatible backend, including [Pydantic Logfire](https://logfire.pydantic.dev/docs). All you need to do is configure Logfire via `logfire.configure`: logfire_integration.py ```python import logfire from judge_recipes import recipe_dataset, transform_recipe logfire.configure( send_to_logfire='if-token-present', # (1)! environment='development', # (2)! service_name='evals', # (3)! ) recipe_dataset.evaluate_sync(transform_recipe) ``` 1. The `send_to_logfire` argument controls when traces are sent to Logfire. You can set it to `'if-token-present'` to send data to Logfire only if the `LOGFIRE_TOKEN` environment variable is set. See the [Logfire configuration docs](https://logfire.pydantic.dev/docs/reference/configuration/) for more details. 1. The `environment` argument sets the environment for the traces. It's a good idea to set this to `'development'` when running tests or evaluations and sending data to a project with production data, to make it easier to filter these traces out while reviewing data from your production environment(s). 1. The `service_name` argument sets the service name for the traces. This is displayed in the Logfire UI to help you identify the source of the associated spans. 
Logfire has some special integration with Pydantic Evals traces, including a table view of the evaluation results on the evaluation root span (which is generated in each call to Dataset.evaluate), and a detailed view of the inputs and outputs for the execution of each case. In addition, any OpenTelemetry spans generated during the evaluation process will be sent to Logfire, allowing you to visualize the full execution of the code called during the evaluation process. This can be especially helpful when attempting to write evaluators that make use of the `span_tree` property of the EvaluatorContext, as described in the [OpenTelemetry Integration](#opentelemetry-integration) section above. This allows you to write evaluations that depend on information about which code paths were executed during the call to the task function without needing to manually instrument the code being evaluated, as long as the code being evaluated is already adequately instrumented with OpenTelemetry. In the case of PydanticAI agents, for example, this can be used to ensure specific tools are (or are not) called during the execution of specific cases. Using OpenTelemetry in this way also means that all data used to evaluate the task executions will be accessible in the traces produced by production runs of the code, making it straightforward to perform the same evaluations on production data. # MCP # Model Context Protocol (MCP) PydanticAI supports [Model Context Protocol (MCP)](https://modelcontextprotocol.io) in three ways: 1. [Agents](../agents/) act as an MCP Client, connecting to MCP servers to use their tools, [learn more …](client/) 1. Agents can be used within MCP servers, [learn more …](server/) 1. As part of PydanticAI, we're building a number of MCP servers, [see below](#mcp-servers) ## What is MCP? The Model Context Protocol is a standardized protocol that allows AI applications (including programmatic agents like PydanticAI, coding agents like [cursor](https://www.cursor.com/), and desktop applications like [Claude Desktop](https://claude.ai/download)) to connect to external tools and services using a common interface. As with other protocols, the dream of MCP is that a wide range of applications can speak to each other without the need for specific integrations. There is a great list of MCP servers at [github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers). Some examples of what this means: - PydanticAI could use a web search service implemented as an MCP server to implement a deep research agent - Cursor could connect to the [Pydantic Logfire](https://github.com/pydantic/logfire-mcp) MCP server to search logs, traces and metrics to gain context while fixing a bug - PydanticAI, or any other MCP client, could connect to our [Run Python](run-python/) MCP server to run arbitrary Python code in a sandboxed environment ## MCP Servers To add functionality to PydanticAI while making it as widely usable as possible, we're implementing some functionality as MCP servers. So far, we've only implemented one MCP server as part of PydanticAI: - [Run Python](run-python/): A sandboxed Python interpreter that can run arbitrary code, with a focus on security and safety. # Client PydanticAI can act as an [MCP client](https://modelcontextprotocol.io/quickstart/client), connecting to MCP servers to use their tools. 
## Install You need to either install [`pydantic-ai`](../../install/), or [`pydantic-ai-slim`](../../install/#slim-install) with the `mcp` optional group: ```bash pip install "pydantic-ai-slim[mcp]" ``` ```bash uv add "pydantic-ai-slim[mcp]" ``` Note MCP integration requires Python 3.10 or higher. ## Usage PydanticAI comes with three ways to connect to MCP servers: - MCPServerSSE which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport - MCPServerStreamableHTTP which connects to an MCP server using the [Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport - MCPServerStdio which runs the server as a subprocess and connects to it using the [stdio](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) transport Examples of each are shown below; [mcp-run-python](../run-python/) is used as the MCP server in the SSE and stdio examples. ### SSE Client MCPServerSSE connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. Note MCPServerSSE requires an MCP server to be running and accepting HTTP connections before calling agent.run_mcp_servers(). Running the server is not managed by PydanticAI. The name "HTTP" is used since this implementation will be adapted in future to use the new [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. Before creating the SSE client, we need to run the server (docs [here](../run-python/)): terminal (run sse server) ```bash deno run \ -N -R=node_modules -W=node_modules --node-modules-dir=auto \ jsr:@pydantic/mcp-run-python sse ``` mcp_sse_client.py ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! async def main(): async with agent.run_mcp_servers(): # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. ``` 1. Define the MCP server with the URL used to connect. 1. Create an agent with the MCP server attached. 1. Create a client session to connect to the server. *(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)* **What's happening here?** - The model is receiving the prompt "how many days between 2000-01-01 and 2025-03-18?" - The model decides "Oh, I've got this `run_python_code` tool, that will be a good way to answer this question", and writes some python code to calculate the answer. - The model returns a tool call - PydanticAI sends the tool call to the MCP server using the SSE transport - The model is called again with the return value of running the code - The model returns the final answer You can visualise this clearly, and even see the code that's run by adding three lines of code to instrument the example with [logfire](https://logfire.pydantic.dev/docs): mcp_sse_client_logfire.py ```python import logfire logfire.configure() logfire.instrument_pydantic_ai() ``` The instrumented run will then be displayed in Logfire. ### Streamable HTTP Client MCPServerStreamableHTTP connects over HTTP using the [Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server. 
Note MCPServerStreamableHTTP requires an MCP server to be running and accepting HTTP connections before calling agent.run_mcp_servers(). Running the server is not managed by PydanticAI. Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. streamable_http_server.py ```python from mcp.server.fastmcp import FastMCP app = FastMCP() @app.tool() def add(a: int, b: int) -> int: return a + b if __name__ == '__main__': app.run(transport='streamable-http') ``` Then we can create the client: mcp_streamable_http_client.py ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! async def main(): async with agent.run_mcp_servers(): # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. ``` 1. Define the MCP server with the URL used to connect. 1. Create an agent with the MCP server attached. 1. Create a client session to connect to the server. *(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)* ### MCP "stdio" Server The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the MCPServerStdio class. Note When using MCPServerStdio servers, the agent.run_mcp_servers() context manager is responsible for starting and stopping the server. mcp_stdio_client.py ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio( # (1)! 'deno', args=[ 'run', '-N', '-R=node_modules', '-W=node_modules', '--node-modules-dir=auto', 'jsr:@pydantic/mcp-run-python', 'stdio', ] ) agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. ``` 1. See [MCP Run Python](../run-python/) for more information. ## Tool call customisation The MCP servers provide the ability to set a `process_tool_call` which allows the customisation of tool call requests and their responses. A common use case for this is to inject metadata to the requests which the server call needs. 
mcp_process_tool_call.py ```python from typing import Any from pydantic_ai import Agent from pydantic_ai.mcp import CallToolFunc, MCPServerStdio, ToolResult from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, tool_name: str, args: dict[str, Any], ) -> ToolResult: """A tool call processor that passes along the deps.""" return await call_tool(tool_name, args, metadata={'deps': ctx.deps}) server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call) agent = Agent( model=TestModel(call_tools=['echo_deps']), deps_type=int, mcp_servers=[server] ) async def main(): async with agent.run_mcp_servers(): result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} ``` ## Using Tool Prefixes to Avoid Naming Conflicts When connecting to multiple MCP servers that might provide tools with the same name, you can use the `tool_prefix` parameter to avoid naming conflicts. This parameter adds a prefix to all tool names from a specific server. ### How It Works - If `tool_prefix` is set, all tools from that server will be prefixed with `{tool_prefix}_` - When listing tools, the prefixed names are shown to the model - When calling tools, the prefix is automatically removed before sending the request to the server This allows you to use multiple servers that might have overlapping tool names without conflicts. ### Example with HTTP Server mcp_tool_prefix_http_client.py ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerSSE # Create two servers with different prefixes weather_server = MCPServerSSE( url='http://localhost:3001/sse', tool_prefix='weather' # Tools will be prefixed with 'weather_' ) calculator_server = MCPServerSSE( url='http://localhost:3002/sse', tool_prefix='calc' # Tools will be prefixed with 'calc_' ) # Both servers might have a tool named 'get_data', but they'll be exposed as: # - 'weather_get_data' # - 'calc_get_data' agent = Agent('openai:gpt-4o', mcp_servers=[weather_server, calculator_server]) ``` ### Example with Stdio Server mcp_tool_prefix_stdio_client.py ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio python_server = MCPServerStdio( 'deno', args=[ 'run', '-N', 'jsr:@pydantic/mcp-run-python', 'stdio', ], tool_prefix='py' # Tools will be prefixed with 'py_' ) js_server = MCPServerStdio( 'node', args=[ 'run', 'mcp-js-server.js', 'stdio', ], tool_prefix='js' # Tools will be prefixed with 'js_' ) agent = Agent('openai:gpt-4o', mcp_servers=[python_server, js_server]) ``` When the model interacts with these servers, it will see the prefixed tool names, but the prefixes will be automatically handled when making tool calls. ## MCP Sampling What is MCP Sampling? In MCP [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used. Sampling is extremely useful when MCP servers need to use Gen AI but you don't want to provision them each with their own LLM credentials or when a public MCP server would like the connecting client to pay for LLM calls. Confusingly it has nothing to do with the concept of "sampling" in observability, or frankly the concept of "sampling" in any other domain. 
Sampling Diagram Here's a mermaid diagram that may or may not make the data flow clearer: ``` sequenceDiagram participant LLM participant MCP_Client as MCP client participant MCP_Server as MCP server MCP_Client->>LLM: LLM call LLM->>MCP_Client: LLM tool call response MCP_Client->>MCP_Server: tool call MCP_Server->>MCP_Client: sampling "create message" MCP_Client->>LLM: LLM call LLM->>MCP_Client: LLM text response MCP_Client->>MCP_Server: sampling response MCP_Server->>MCP_Client: tool call response ``` Pydantic AI supports sampling as both a client and server. See the [server](../server/#mcp-sampling) documentation for details on how to use sampling within a server. Sampling is automatically supported by Pydantic AI agents when they act as a client. Let's say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments). Sampling MCP Server generate_svg.py ````python import re from pathlib import Path from mcp import SamplingMessage from mcp.server.fastmcp import Context, FastMCP from mcp.types import TextContent app = FastMCP() @app.tool() async def image_generator(ctx: Context, subject: str, style: str) -> str: prompt = f'{subject=} {style=}' # `ctx.session.create_message` is the sampling call result = await ctx.session.create_message( [SamplingMessage(role='user', content=TextContent(type='text', text=prompt))], max_tokens=1_024, system_prompt='Generate an SVG image as per the user input', ) assert isinstance(result.content, TextContent) path = Path(f'{subject}_{style}.svg') # remove triple backticks if the svg was returned within markdown if m := re.search(r'^```\w*$(.+?)```$', result.content.text, re.S | re.M): path.write_text(m.group(1)) else: path.write_text(result.content.text) return f'See {path}' if __name__ == '__main__': # run the server via stdio app.run() ```` Using this server with an `Agent` will automatically allow sampling: sampling_mcp_client.py ```python from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio(command='python', args=['generate_svg.py']) agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): async with agent.run_mcp_servers(): result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. ``` *(This example is complete, it can be run "as is" with Python 3.10+)* You can disallow sampling by setting allow_sampling=False when creating the server reference, e.g.: sampling_disallowed.py ```python from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio( command='python', args=['generate_svg.py'], allow_sampling=False, ) ``` # MCP Run Python The **MCP Run Python** package is an MCP server that allows agents to execute Python code in a secure, sandboxed environment. It uses [Pyodide](https://pyodide.org/) to run Python code in a JavaScript environment with [Deno](https://deno.com/), isolating execution from the host system. ## Features - **Secure Execution**: Run Python code in a sandboxed WebAssembly environment - **Package Management**: Automatically detects and installs required dependencies - **Complete Results**: Captures standard output, standard error, and return values - **Asynchronous Support**: Runs async code properly - **Error Handling**: Provides detailed error reports for debugging ## Installation Switch from npx to deno We previously distributed `mcp-run-python` as an `npm` package to use via `npx`. 
We now recommend using `deno` instead as it provides better sandboxing and security. The MCP Run Python server is distributed as a [JSR package](https://jsr.io/@pydantic/mcp-run-python) and can be run directly using [`deno run`](https://deno.com/): terminal ```bash deno run \ -N -R=node_modules -W=node_modules --node-modules-dir=auto \ jsr:@pydantic/mcp-run-python [stdio|sse|warmup] ``` where: - `-N -R=node_modules -W=node_modules` (alias of `--allow-net --allow-read=node_modules --allow-write=node_modules`) allows network access and read+write access to `./node_modules`. These are required so Pyodide can download and cache the Python standard library and packages - `--node-modules-dir=auto` tells deno to use a local `node_modules` directory - `stdio` runs the server with the [Stdio MCP transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) — suitable for running the process as a subprocess locally - `sse` runs the server with the [SSE MCP transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) — running the server as an HTTP server to connect locally or remotely - `warmup` will run a minimal Python script to download and cache the Python standard library. This is also useful to check the server is running correctly. Usage of `jsr:@pydantic/mcp-run-python` with PydanticAI is described in the [client](../client/#mcp-stdio-server) documentation. ## Direct Usage As well as using this server with PydanticAI, it can be connected to other MCP clients. For clarity, in this example we connect directly using the [Python MCP client](https://github.com/modelcontextprotocol/python-sdk). mcp_run_python.py ```python from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client code = """ import numpy a = numpy.array([1, 2, 3]) print(a) a """ server_params = StdioServerParameters( command='deno', args=[ 'run', '-N', '-R=node_modules', '-W=node_modules', '--node-modules-dir=auto', 'jsr:@pydantic/mcp-run-python', 'stdio', ], ) async def main(): async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: await session.initialize() tools = await session.list_tools() print(len(tools.tools)) #> 1 print(repr(tools.tools[0].name)) #> 'run_python_code' print(repr(tools.tools[0].inputSchema)) """ {'type': 'object', 'properties': {'python_code': {'type': 'string', 'description': 'Python code to run'}}, 'required': ['python_code'], 'additionalProperties': False, '$schema': 'http://json-schema.org/draft-07/schema#'} """ result = await session.call_tool('run_python_code', {'python_code': code}) print(result.content[0].text) """ success ["numpy"] [1 2 3] [ 1, 2, 3 ] """ ``` If an exception occurs, `status` will be `install-error` or `run-error` and `return_value` will be replaced by `error` which will include the traceback and exception message. ## Dependencies Dependencies are installed when code is run. Dependencies can be defined in one of two ways: ### Inferred from imports If there's no metadata, dependencies are inferred from imports in the code, as shown in the example [above](#direct-usage). ### Inline script metadata As introduced in PEP 723, explained [here](https://packaging.python.org/en/latest/specifications/inline-script-metadata/#inline-script-metadata), and popularized by [uv](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies) — dependencies can be defined in a comment at the top of the file. 
This allows use of dependencies that aren't imported in the code, and is more explicit. inline_script_metadata.py ```py from mcp import ClientSession from mcp.client.stdio import stdio_client # using `server_params` from the above example. from mcp_run_python import server_params code = """\ # /// script # dependencies = ["pydantic", "email-validator"] # /// import pydantic class Model(pydantic.BaseModel): email: pydantic.EmailStr print(Model(email='hello@pydantic.dev')) """ async def main(): async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: await session.initialize() result = await session.call_tool('run_python_code', {'python_code': code}) print(result.content[0].text) """ success ["pydantic","email-validator"] email='hello@pydantic.dev' """ ``` It also allows versions to be pinned for non-binary packages (Pyodide only supports a single version for the binary packages it supports, like `pydantic` and `numpy`). E.g. you could set the dependencies to ```python # /// script # dependencies = ["rich<13"] # /// ``` ## Logging MCP Run Python supports emitting stdout and stderr from the python execution as [MCP logging messages](https://github.com/modelcontextprotocol/specification/blob/eb4abdf2bb91e0d5afd94510741eadd416982350/docs/specification/draft/server/utilities/logging.md?plain=1). For logs to be emitted you must set the logging level when connecting to the server. By default, the log level is set to the highest level, `emergency`. Currently, it's not possible to demonstrate this due to a bug in the Python MCP Client, see [modelcontextprotocol/python-sdk#201](https://github.com/modelcontextprotocol/python-sdk/issues/201#issuecomment-2727663121). # Server PydanticAI models can also be used within MCP Servers. ## MCP Server Here's a simple example of a [Python MCP server](https://github.com/modelcontextprotocol/python-sdk) using PydanticAI within a tool call: mcp_server.py ```py from mcp.server.fastmcp import FastMCP from pydantic_ai import Agent server = FastMCP('PydanticAI Server') server_agent = Agent( 'anthropic:claude-3-5-haiku-latest', system_prompt='always reply in rhyme' ) @server.tool() async def poet(theme: str) -> str: """Poem generator""" r = await server_agent.run(f'write a poem about {theme}') return r.output if __name__ == '__main__': server.run() ``` ## Simple client This server can be queried with any MCP client. Here is an example using the Python SDK directly: mcp_client.py ```py import asyncio import os from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client async def client(): server_params = StdioServerParameters( command='python', args=['mcp_server.py'], env=os.environ ) async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: await session.initialize() result = await session.call_tool('poet', {'theme': 'socks'}) print(result.content[0].text) """ Oh, socks, those garments soft and sweet, That nestle softly 'round our feet, From cotton, wool, or blended thread, They keep our toes from feeling dread. """ if __name__ == '__main__': asyncio.run(client()) ``` ## MCP Sampling What is MCP Sampling? See the [MCP client docs](../client/#mcp-sampling) for details of what MCP sampling is, and how you can support it when using Pydantic AI as an MCP client. When Pydantic AI agents are used within MCP servers, they can use sampling via MCPSamplingModel. 
We can extend the above example to use sampling so instead of connecting directly to the LLM, the agent calls back through the MCP client to make LLM calls. mcp_server_sampling.py ```py from mcp.server.fastmcp import Context, FastMCP from pydantic_ai import Agent from pydantic_ai.models.mcp_sampling import MCPSamplingModel server = FastMCP('PydanticAI Server with sampling') server_agent = Agent(system_prompt='always reply in rhyme') @server.tool() async def poet(ctx: Context, theme: str) -> str: """Poem generator""" r = await server_agent.run(f'write a poem about {theme}', model=MCPSamplingModel(session=ctx.session)) return r.output if __name__ == '__main__': server.run() # run the server over stdio ``` The [above](#simple-client) client does not support sampling, so if you tried to use it with this server you'd get an error. The simplest way to support sampling in an MCP client is to [use](../client/#mcp-sampling) a Pydantic AI agent as the client, but if you wanted to support sampling with the vanilla MCP SDK, you could do so like this: mcp_client_sampling.py ```py import asyncio from typing import Any from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext from mcp.types import CreateMessageRequestParams, CreateMessageResult, ErrorData, TextContent async def sampling_callback( context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams ) -> CreateMessageResult | ErrorData: print('sampling system prompt:', params.systemPrompt) #> sampling system prompt: always reply in rhyme print('sampling messages:', params.messages) """ sampling messages: [ SamplingMessage( role='user', content=TextContent( type='text', text='write a poem about socks', annotations=None ), ) ] """ # TODO get the response content by calling an LLM... response_content = 'Socks for a fox.' return CreateMessageResult( role='assistant', content=TextContent(type='text', text=response_content), model='fictional-llm', ) async def client(): server_params = StdioServerParameters(command='python', args=['mcp_server_sampling.py']) async with stdio_client(server_params) as (read, write): async with ClientSession(read, write, sampling_callback=sampling_callback) as session: await session.initialize() result = await session.call_tool('poet', {'theme': 'socks'}) print(result.content[0].text) #> Socks for a fox. if __name__ == '__main__': asyncio.run(client()) ``` *(This example is complete, it can be run "as is" with Python 3.10+)* # Optional # Command Line Interface (CLI) **PydanticAI** comes with a CLI, `clai` (pronounced "clay") which you can use to interact with various LLMs from the command line. It provides a convenient way to chat with language models and quickly get answers right in the terminal. We originally developed this CLI for our own use, but found ourselves using it so frequently that we decided to share it as part of the PydanticAI package. We plan to continue adding new features, such as interaction with MCP servers, access to tools, and more. ## Usage You'll need to set an environment variable depending on the provider you intend to use. E.g. if you're using OpenAI, set the `OPENAI_API_KEY` environment variable: ```bash export OPENAI_API_KEY='your-api-key-here' ``` Then with [`uvx`](https://docs.astral.sh/uv/guides/tools/), run: ```bash uvx clai ``` Or to install `clai` globally [with `uv`](https://docs.astral.sh/uv/guides/tools/#installing-tools), run: ```bash uv tool install clai ... 
clai ``` Or with `pip`, run: ```bash pip install clai ... clai ``` Either way, running `clai` will start an interactive session where you can chat with the AI model. Special commands available in interactive mode: - `/exit`: Exit the session - `/markdown`: Show the last response in markdown format - `/multiline`: Toggle multiline input mode (use Ctrl+D to submit) ### Help To get help on the CLI, use the `--help` flag: ```bash uvx clai --help ``` ### Choose a model You can specify which model to use with the `--model` flag: ```bash uvx clai --model anthropic:claude-3-7-sonnet-latest ``` (a full list of models available can be printed with `uvx clai --list-models`) ### Custom Agents You can specify a custom agent using the `--agent` flag with a module path and variable name: custom_agent.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') ``` Then run: ```bash uvx clai --agent custom_agent:agent "What's the weather today?" ``` The format must be `module:variable` where: - `module` is the importable Python module path - `variable` is the name of the Agent instance in that module Additionally, you can directly launch CLI mode from an `Agent` instance using `Agent.to_cli_sync()`: agent_to_cli_sync.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') agent.to_cli_sync() ``` You can also use the async interface with `Agent.to_cli()`: agent_to_cli.py ```python from pydantic_ai import Agent agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') async def main(): await agent.to_cli() ``` *(You'll need to add `asyncio.run(main())` to run `main`)* # Debugging and Monitoring Applications that use LLMs have some challenges that are well known and understood: LLMs are **slow**, **unreliable** and **expensive**. These applications also have some challenges that most developers have encountered much less often: LLMs are **fickle** and **non-deterministic**. Subtle changes in a prompt can completely change a model's performance, and there's no `EXPLAIN` query you can run to understand why. Warning From a software engineers point of view, you can think of LLMs as the worst database you've ever heard of, but worse. If LLMs weren't so bloody useful, we'd never touch them. To build successful applications with LLMs, we need new tools to understand both model performance, and the behavior of applications that rely on them. LLM Observability tools that just let you understand how your model is performing are useless: making API calls to an LLM is easy, it's building that into an application that's hard. ## Pydantic Logfire [Pydantic Logfire](https://pydantic.dev/logfire) is an observability platform developed by the team who created and maintain Pydantic and PydanticAI. Logfire aims to let you understand your entire application: Gen AI, classic predictive AI, HTTP traffic, database queries and everything else a modern application needs, all using OpenTelemetry. Pydantic Logfire is a commercial product Logfire is a commercially supported, hosted platform with an extremely generous and perpetual [free tier](https://pydantic.dev/pricing/). You can sign up and start using Logfire in a couple of minutes. Logfire can also be self-hosted on the enterprise tier. PydanticAI has built-in (but optional) support for Logfire. 
That means if the `logfire` package is installed and configured and agent instrumentation is enabled then detailed information about agent runs is sent to Logfire. Otherwise there's virtually no overhead and nothing is sent. Here's an example showing details of running the [Weather Agent](../examples/weather-agent/) in Logfire: A trace is generated for the agent run, and spans are emitted for each model request and tool call. ## Using Logfire To use Logfire, you'll need a Logfire [account](https://logfire.pydantic.dev), and the Logfire Python SDK installed: ```bash pip install "pydantic-ai[logfire]" ``` ```bash uv add "pydantic-ai[logfire]" ``` Then authenticate your local environment with Logfire: ```bash logfire auth ``` ```bash uv run logfire auth ``` And configure a project to send data to: ```bash logfire projects new ``` ```bash uv run logfire projects new ``` (Or use an existing project with `logfire projects use`) This will write to a `.logfire` directory in the current working directory, which the Logfire SDK will use for configuration at run time. With that, you can start using Logfire to instrument PydanticAI code: instrument_pydantic_ai.py ```python import logfire from pydantic_ai import Agent logfire.configure() # (1)! logfire.instrument_pydantic_ai() # (2)! agent = Agent('openai:gpt-4o', instructions='Be concise, reply with one sentence.') result = agent.run_sync('Where does "hello world" come from?') # (3)! print(result.output) """ The first known use of "hello, world" was in a 1974 textbook about the C programming language. """ ``` 1. logfire.configure() configures the SDK, by default it will find the write token from the `.logfire` directory, but you can also pass a token directly. 1. logfire.instrument_pydantic_ai() enables instrumentation of PydanticAI. 1. Since we've enabled instrumentation, a trace will be generated for each run, with spans emitted for models calls and tool function execution *(This example is complete, it can be run "as is")* Which will display in Logfire thus: The [logfire documentation](https://logfire.pydantic.dev/docs/) has more details on how to use Logfire, including how to instrument other libraries like [HTTPX](https://logfire.pydantic.dev/docs/integrations/http-clients/httpx/) and [FastAPI](https://logfire.pydantic.dev/docs/integrations/web-frameworks/fastapi/). Since Logfire is built on [OpenTelemetry](https://opentelemetry.io/), you can use the Logfire Python SDK to send data to any OpenTelemetry collector, see [below](#using-opentelemetry). ### Debugging To demonstrate how Logfire can let you visualise the flow of a PydanticAI run, here's the view you get from Logfire while running the [chat app examples](../examples/chat-app/): ### Monitoring Performance We can also query data with SQL in Logfire to monitor the performance of an application. Here's a real world example of using Logfire to monitor PydanticAI runs inside Logfire itself: ### Monitoring HTTP Requests "F\*\*k you, show me the prompt." As per Hamel Husain's influential 2024 blog post ["Fuck You, Show Me The Prompt."](https://hamel.dev/blog/posts/prompt/) (bear with the capitalization, the point is valid), it's often useful to be able to view the raw HTTP requests and responses made to model providers. To observe raw HTTP requests made to model providers, you can use `logfire`'s [HTTPX instrumentation](https://logfire.pydantic.dev/docs/integrations/http-clients/httpx/) since all provider SDKs use the [HTTPX](https://www.python-httpx.org/) library internally. 
with_logfire_instrument_httpx.py ```py import logfire from pydantic_ai import Agent logfire.configure() logfire.instrument_pydantic_ai() logfire.instrument_httpx(capture_all=True) # (1)! agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` 1. See the logfire.instrument_httpx docs for more details; `capture_all=True` means both headers and body are captured for both the request and response. without_logfire_instrument_httpx.py ```py import logfire from pydantic_ai import Agent logfire.configure() logfire.instrument_pydantic_ai() agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` ## Using OpenTelemetry PydanticAI's instrumentation uses [OpenTelemetry](https://opentelemetry.io/) (OTel), which Logfire is based on. This means you can debug and monitor PydanticAI with any OpenTelemetry backend. PydanticAI follows the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/), so while we think you'll have the best experience using the Logfire platform, you should be able to use any OTel service with GenAI support. ### Logfire with an alternative OTel backend You can use the Logfire SDK completely freely and send the data to any OpenTelemetry backend. Here's an example of configuring the Logfire library to send data to the excellent [otel-tui](https://github.com/ymtdzzz/otel-tui) — an open source terminal based OTel backend and viewer (no association with Pydantic). Run `otel-tui` with docker (see [the otel-tui readme](https://github.com/ymtdzzz/otel-tui) for more instructions): Terminal ```text docker run --rm -it -p 4318:4318 --name otel-tui ymtdzzz/otel-tui:latest ``` then run: otel_tui.py ```python import os import logfire from pydantic_ai import Agent os.environ['OTEL_EXPORTER_OTLP_ENDPOINT'] = 'http://localhost:4318' # (1)! logfire.configure(send_to_logfire=False) # (2)! logfire.instrument_pydantic_ai() logfire.instrument_httpx(capture_all=True) agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` 1. Set the `OTEL_EXPORTER_OTLP_ENDPOINT` environment variable to the URL of your OpenTelemetry backend. If you're using a backend that requires authentication, you may need to set [other environment variables](https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter/). Of course, these can also be set outside the process, e.g. with `export OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318`. 1. We configure Logfire to disable sending data to the Logfire OTel backend itself. If you removed `send_to_logfire=False`, data would be sent to both Logfire and your OpenTelemetry backend. Running the above code will send tracing data to `otel-tui`, which will display like this: Running the [weather agent](../examples/weather-agent/) example connected to `otel-tui` shows how it can be used to visualise a more complex trace: For more information on using the Logfire SDK to send data to alternative backends, see [the Logfire documentation](https://logfire.pydantic.dev/docs/how-to-guides/alternative-backends/). ### OTel without Logfire You can also emit OpenTelemetry data from PydanticAI without using Logfire at all. To do this, you'll need to install and configure the OpenTelemetry packages you need. 
To run the following examples, use Terminal ```text uv run \ --with 'pydantic-ai-slim[openai]' \ --with opentelemetry-sdk \ --with opentelemetry-exporter-otlp \ raw_otel.py ``` raw_otel.py ```python import os from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.trace import set_tracer_provider from pydantic_ai.agent import Agent os.environ['OTEL_EXPORTER_OTLP_ENDPOINT'] = 'http://localhost:4318' exporter = OTLPSpanExporter() span_processor = BatchSpanProcessor(exporter) tracer_provider = TracerProvider() tracer_provider.add_span_processor(span_processor) set_tracer_provider(tracer_provider) Agent.instrument_all() agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` ### Alternative Observability backends Because Pydantic AI uses OpenTelemetry for observability, you can easily configure it to send data to any OpenTelemetry-compatible backend, not just our observability platform [Pydantic Logfire](#pydantic-logfire). The following providers have dedicated documentation on Pydantic AI: - [Langfuse](https://langfuse.com/docs/integrations/pydantic-ai) - [W&B Weave](https://weave-docs.wandb.ai/guides/integrations/pydantic_ai/) - [Arize](https://arize.com/docs/ax/observe/tracing-integrations-auto/pydantic-ai) - [Openlayer](https://www.openlayer.com/docs/integrations/pydantic-ai) - [OpenLIT](https://docs.openlit.io/latest/integrations/pydantic) - [LangWatch](https://docs.langwatch.ai/integration/python/integrations/pydantic-ai) - [Patronus AI](https://docs.patronus.ai/docs/percival/pydantic) - [Opik](https://www.comet.com/docs/opik/tracing/integrations/pydantic-ai) - [mlflow](https://mlflow.org/docs/latest/genai/tracing/integrations/listing/pydantic_ai) ## Advanced usage ### Configuring data format PydanticAI follows the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/), with one caveat. The semantic conventions specify that messages should be captured as individual events (logs) that are children of the request span. By default, PydanticAI instead collects these events into a JSON array which is set as a single large attribute called `events` on the request span. To change this, use `event_mode='logs'`: instrumentation_settings_event_mode.py ```python import logfire from pydantic_ai import Agent logfire.configure() logfire.instrument_pydantic_ai(event_mode='logs') agent = Agent('openai:gpt-4o') result = agent.run_sync('What is the capital of France?') print(result.output) #> Paris ``` For now, this won't look as good in the Logfire UI, but we're working on it. If you have very long conversations, the `events` span attribute may be truncated. Using `event_mode='logs'` will help avoid this issue. Note that the OpenTelemetry Semantic Conventions are still experimental and are likely to change. ### Setting OpenTelemetry SDK providers By default, the global `TracerProvider` and `EventLoggerProvider` are used. These are set automatically by `logfire.configure()`. They can also be set by the `set_tracer_provider` and `set_event_logger_provider` functions in the OpenTelemetry Python SDK. You can set custom providers with InstrumentationSettings. 
instrumentation_settings_providers.py ```python from opentelemetry.sdk._events import EventLoggerProvider from opentelemetry.sdk.trace import TracerProvider from pydantic_ai.agent import Agent, InstrumentationSettings instrumentation_settings = InstrumentationSettings( tracer_provider=TracerProvider(), event_logger_provider=EventLoggerProvider(), ) agent = Agent('gpt-4o', instrument=instrumentation_settings) # or to instrument all agents: Agent.instrument_all(instrumentation_settings) ``` ### Instrumenting a specific `Model` instrumented_model_example.py ```python from pydantic_ai import Agent from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel settings = InstrumentationSettings() model = InstrumentedModel('gpt-4o', settings) agent = Agent(model) ``` ### Excluding binary content excluding_binary_content.py ```python from pydantic_ai.agent import Agent, InstrumentationSettings instrumentation_settings = InstrumentationSettings(include_binary_content=False) agent = Agent('gpt-4o', instrument=instrumentation_settings) # or to instrument all agents: Agent.instrument_all(instrumentation_settings) ``` ### Excluding prompts and completions For privacy and security reasons, you may want to monitor your agent's behavior and performance without exposing sensitive user data or proprietary prompts in your observability platform. PydanticAI allows you to exclude the actual content from instrumentation events while preserving the structural information needed for debugging and monitoring. When `include_content=False` is set, PydanticAI will exclude sensitive content from OpenTelemetry events, including user prompts and model completions, tool call arguments and responses, and any other message content. excluding_sensitive_content.py ```python from pydantic_ai.agent import Agent from pydantic_ai.models.instrumented import InstrumentationSettings instrumentation_settings = InstrumentationSettings(include_content=False) agent = Agent('gpt-4o', instrument=instrumentation_settings) # or to instrument all agents: Agent.instrument_all(instrumentation_settings) ``` This setting is particularly useful in production environments where compliance requirements or data sensitivity concerns make it necessary to limit what content is sent to your observability platform. # Unit testing Writing unit tests for PydanticAI code is just like unit tests for any other Python code. Because for the most part they're nothing new, we have pretty well established tools and patterns for writing and running these kinds of tests. 
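To give a sense of scale before the detailed walkthrough below, a complete test of an agent can be as small as the following sketch. Here `my_agent` is a hypothetical agent under test rather than part of the examples that follow; each ingredient used here is covered in the strategy list and sections below.

```python
from pydantic_ai import Agent, models
from pydantic_ai.models.test import TestModel

models.ALLOW_MODEL_REQUESTS = False  # fail fast if anything tries to reach a real model

# hypothetical agent under test; in practice you'd import this from your application
my_agent = Agent('openai:gpt-4o', system_prompt='Be concise.')


def test_my_agent():
    # swap the real model for TestModel so no LLM request is made
    with my_agent.override(model=TestModel()):
        result = my_agent.run_sync('Hello')
    assert result.output  # TestModel returns deterministic placeholder text
```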
Unless you're really sure you know better, you'll probably want to follow roughly this strategy: - Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness - If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) - Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures - Use TestModel or FunctionModel in place of your actual model to avoid the usage, latency and variability of real LLM calls - Use Agent.override to replace your model inside your application logic - Set ALLOW_MODEL_REQUESTS=False globally to block any requests from being made to non-test models accidentally ### Unit testing with `TestModel` The simplest and fastest way to exercise most of your application code is using TestModel, this will (by default) call all tools in the agent, then return either plain text or a structured response depending on the return type of the agent. `TestModel` is not magic The "clever" (but not too clever) part of `TestModel` is that it will attempt to generate valid structured data for [function tools](../tools/) and [output types](../output/#structured-output) based on the schema of the registered tools. There's no ML or AI in `TestModel`, it's just plain old procedural Python code that tries to generate data that satisfies the JSON schema of a tool. The resulting data won't look pretty or relevant, but it should pass Pydantic's validation in most cases. If you want something more sophisticated, use FunctionModel and write your own data generation logic. Let's write unit tests for the following application code: weather_app.py ```python import asyncio from datetime import date from pydantic_ai import Agent, RunContext from fake_database import DatabaseConn # (1)! from weather_service import WeatherService # (2)! weather_agent = Agent( 'openai:gpt-4o', deps_type=WeatherService, system_prompt='Providing a weather forecast at the locations the user provides.', ) @weather_agent.tool def weather_forecast( ctx: RunContext[WeatherService], location: str, forecast_date: date ) -> str: if forecast_date < date.today(): # (3)! return ctx.deps.get_historic_weather(location, forecast_date) else: return ctx.deps.get_forecast(location, forecast_date) async def run_weather_forecast( # (4)! user_prompts: list[tuple[str, int]], conn: DatabaseConn ): """Run weather forecast for a list of user prompts and save.""" async with WeatherService() as weather_service: async def run_forecast(prompt: str, user_id: int): result = await weather_agent.run(prompt, deps=weather_service) await conn.store_forecast(user_id, result.output) # run all prompts in parallel await asyncio.gather( *(run_forecast(prompt, user_id) for (prompt, user_id) in user_prompts) ) ``` 1. `DatabaseConn` is a class that holds a database connection 1. `WeatherService` has methods to get weather forecasts and historic data about the weather 1. We need to call a different endpoint depending on whether the date is in the past or the future, you'll see why this nuance is important below 1. This function is the code we want to test, together with the agent it uses Here we have a function that takes a list of `(user_prompt, user_id)` tuples, gets a weather forecast for each prompt, and stores the result in the database. 
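The example above imports `DatabaseConn` from `fake_database` and `WeatherService` from `weather_service`, neither of which is shown here. If you want to run it locally, minimal stand-ins along the following lines should be enough; this is a sketch inferred only from the methods used above, not the actual helper modules.

```python
# Minimal stand-ins (assumed interfaces, inferred from how weather_app.py uses them;
# not the official example code). Put DatabaseConn in fake_database.py and
# WeatherService in weather_service.py so the imports in weather_app.py resolve.
from datetime import date


class DatabaseConn:
    """In-memory stand-in for the database connection."""

    def __init__(self):
        self._forecasts: dict[int, str] = {}

    async def store_forecast(self, user_id: int, forecast: str) -> None:
        self._forecasts[user_id] = forecast

    async def get_forecast(self, user_id: int) -> str:
        return self._forecasts[user_id]


class WeatherService:
    """Stand-in weather service used as the agent's dependency."""

    async def __aenter__(self) -> 'WeatherService':
        return self

    async def __aexit__(self, *exc_info) -> None:
        pass

    def get_forecast(self, location: str, forecast_date: date) -> str:
        return f'Forecast for {location} on {forecast_date}'

    def get_historic_weather(self, location: str, forecast_date: date) -> str:
        return f'Historic weather for {location} on {forecast_date}'
```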
**We want to test this code without having to mock certain objects or modify our code so we can pass test objects in.** Here's how we would write tests using TestModel: test_weather_app.py ```python from datetime import timezone import pytest from dirty_equals import IsNow, IsStr from pydantic_ai import models, capture_run_messages from pydantic_ai.models.test import TestModel from pydantic_ai.messages import ( ModelResponse, SystemPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, ModelRequest, ) from pydantic_ai.usage import Usage from fake_database import DatabaseConn from weather_app import run_weather_forecast, weather_agent pytestmark = pytest.mark.anyio # (1)! models.ALLOW_MODEL_REQUESTS = False # (2)! async def test_forecast(): conn = DatabaseConn() user_id = 1 with capture_run_messages() as messages: with weather_agent.override(model=TestModel()): # (3)! prompt = 'What will the weather be like in London on 2024-11-28?' await run_weather_forecast([(prompt, user_id)], conn) # (4)! forecast = await conn.get_forecast(user_id) assert forecast == '{"weather_forecast":"Sunny with a chance of rain"}' # (5)! assert messages == [ # (6)! ModelRequest( parts=[ SystemPromptPart( content='Providing a weather forecast at the locations the user provides.', timestamp=IsNow(tz=timezone.utc), ), UserPromptPart( content='What will the weather be like in London on 2024-11-28?', timestamp=IsNow(tz=timezone.utc), # (7)! ), ] ), ModelResponse( parts=[ ToolCallPart( tool_name='weather_forecast', args={ 'location': 'a', 'forecast_date': '2024-01-01', # (8)! }, tool_call_id=IsStr(), ) ], usage=Usage( requests=1, request_tokens=71, response_tokens=7, total_tokens=78, details=None, ), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='weather_forecast', content='Sunny with a chance of rain', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ], ), ModelResponse( parts=[ TextPart( content='{"weather_forecast":"Sunny with a chance of rain"}', ) ], usage=Usage( requests=1, request_tokens=77, response_tokens=16, total_tokens=93, details=None, ), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ``` 1. We're using [anyio](https://anyio.readthedocs.io/en/stable/) to run async tests. 1. This is a safety measure to make sure we don't accidentally make real requests to the LLM while testing; see ALLOW_MODEL_REQUESTS for more details. 1. We're using Agent.override to replace the agent's model with TestModel; the nice thing about `override` is that we can replace the model inside our application logic without needing access to the call sites of the agent's `run*` methods. 1. Now we call the function we want to test inside the `override` context manager. 1. By default, `TestModel` will return a JSON string summarising the tool calls made and what was returned. If you wanted to customise the response to something more closely aligned with the domain, you could add custom_output_text='Sunny' when defining `TestModel`. 1. So far we don't actually know which tools were called and with which values; we can use capture_run_messages to inspect messages from the most recent run and assert the exchange between the agent and the model occurred as expected. 1. The IsNow helper allows us to use declarative asserts even with data which will contain timestamps that change over time. 1. `TestModel` isn't doing anything clever to extract values from the prompt, so these values are hardcoded. 
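As note 5 above mentions, the canned final response can be customised with `custom_output_text`. A variant of the test that pins the stored forecast to a domain-flavoured string might look roughly like this (a sketch that reuses the same imports and anyio/pytest setup as `test_weather_app.py` above):

```python
from pydantic_ai.models.test import TestModel

from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent


async def test_forecast_custom_text():
    conn = DatabaseConn()
    user_id = 1
    # TestModel still calls the tools, but returns the fixed text below
    # instead of its default JSON summary of the tool calls.
    with weather_agent.override(model=TestModel(custom_output_text='Sunny')):
        await run_weather_forecast(
            [('What will the weather be like in London on 2024-11-28?', user_id)], conn
        )

    assert await conn.get_forecast(user_id) == 'Sunny'
```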
### Unit testing with `FunctionModel` The above tests are a great start, but careful readers will notice that the `WeatherService.get_forecast` is never called since `TestModel` calls `weather_forecast` with a date in the past. To fully exercise `weather_forecast`, we need to use FunctionModel to customise how the tool is called. Here's an example of using `FunctionModel` to test the `weather_forecast` tool with custom inputs: test_weather_app2.py ```python import re import pytest from pydantic_ai import models from pydantic_ai.messages import ( ModelMessage, ModelResponse, TextPart, ToolCallPart, ) from pydantic_ai.models.function import AgentInfo, FunctionModel from fake_database import DatabaseConn from weather_app import run_weather_forecast, weather_agent pytestmark = pytest.mark.anyio models.ALLOW_MODEL_REQUESTS = False def call_weather_forecast( # (1)! messages: list[ModelMessage], info: AgentInfo ) -> ModelResponse: if len(messages) == 1: # first call, call the weather forecast tool user_prompt = messages[0].parts[-1] m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content) assert m is not None args = {'location': 'London', 'forecast_date': m.group()} # (2)! return ModelResponse(parts=[ToolCallPart('weather_forecast', args)]) else: # second call, return the forecast msg = messages[-1].parts[0] assert msg.part_kind == 'tool-return' return ModelResponse(parts=[TextPart(f'The forecast is: {msg.content}')]) async def test_forecast_future(): conn = DatabaseConn() user_id = 1 with weather_agent.override(model=FunctionModel(call_weather_forecast)): # (3)! prompt = 'What will the weather be like in London on 2032-01-01?' await run_weather_forecast([(prompt, user_id)], conn) forecast = await conn.get_forecast(user_id) assert forecast == 'The forecast is: Rainy with a chance of sun' ``` 1. We define a function `call_weather_forecast` that will be called by `FunctionModel` in place of the LLM; this function has access to the list of ModelMessages that make up the run, and AgentInfo which contains information about the agent and the function tools and return tools. 1. Our function is slightly intelligent in that it tries to extract a date from the prompt, but just hard codes the location. 1. We use FunctionModel to replace the agent's model with our custom function. ### Overriding model via pytest fixtures If you're writing lots of tests that all require the model to be overridden, you can use [pytest fixtures](https://docs.pytest.org/en/6.2.x/fixture.html) to override the model with TestModel or FunctionModel in a reusable way. Here's an example of a fixture that overrides the model with `TestModel`: test_agent.py ```python import pytest from weather_app import weather_agent from pydantic_ai.models.test import TestModel @pytest.fixture def override_weather_agent(): with weather_agent.override(model=TestModel()): yield async def test_forecast(override_weather_agent: None): ... # test code here ``` # Examples Examples of how to use PydanticAI and what it can do. ## Usage These examples are distributed with `pydantic-ai` so you can run them either by cloning the [pydantic-ai repo](https://github.com/pydantic/pydantic-ai) or by simply installing `pydantic-ai` from PyPI with `pip` or `uv`. ### Installing required dependencies Either way you'll need to install extra dependencies to run some examples; you just need to install the `examples` optional dependency group. 
If you've installed `pydantic-ai` via pip/uv, you can install the extra dependencies with: ```bash pip install "pydantic-ai[examples]" ``` ```bash uv add "pydantic-ai[examples]" ``` If you clone the repo, you should instead use `uv sync --extra examples` to install extra dependencies. ### Setting model environment variables These examples will need you to set up authentication with one or more of the LLMs, see the [model configuration](../models/) docs for details on how to do this. TL;DR: in most cases you'll need to set one of the following environment variables: ```bash export OPENAI_API_KEY=your-api-key ``` ```bash export GEMINI_API_KEY=your-api-key ``` ### Running Examples To run the examples (this will work whether you installed `pydantic_ai`, or cloned the repo), run: ```bash python -m pydantic_ai_examples. ``` ```bash uv run -m pydantic_ai_examples. ``` For example, to run the very simple [`pydantic_model`](pydantic-model/) example: ```bash python -m pydantic_ai_examples.pydantic_model ``` ```bash uv run -m pydantic_ai_examples.pydantic_model ``` If you like one-liners and you're using uv, you can run a pydantic-ai example with zero setup: ```bash OPENAI_API_KEY='your-api-key' \ uv run --with "pydantic-ai[examples]" \ -m pydantic_ai_examples.pydantic_model ``` ______________________________________________________________________ You'll probably want to edit examples in addition to just running them. You can copy the examples to a new directory with: ```bash python -m pydantic_ai_examples --copy-to examples/ ``` ```bash uv run -m pydantic_ai_examples --copy-to examples/ ``` # Bank Support Small but complete example of using PydanticAI to build a support agent for a bank. Demonstrates: - [dynamic system prompt](../../agents/#system-prompts) - [structured `output_type`](../../output/#structured-output) - [tools](../../tools/) ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.bank_support ``` ```bash uv run -m pydantic_ai_examples.bank_support ``` (or `PYDANTIC_AI_MODEL=gemini-1.5-flash ...`) ## Example Code [bank_support.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/bank_support.py) ```py """Small but complete example of using PydanticAI to build a support agent for a bank. Run with: uv run -m pydantic_ai_examples.bank_support """ from dataclasses import dataclass from pydantic import BaseModel, Field from pydantic_ai import Agent, RunContext class DatabaseConn: """This is a fake database for example purposes. In reality, you'd be connecting to an external database (e.g. PostgreSQL) to get information about customers. """ @classmethod async def customer_name(cls, *, id: int) -> str | None: if id == 123: return 'John' @classmethod async def customer_balance(cls, *, id: int, include_pending: bool) -> float: if id == 123: if include_pending: return 123.45 else: return 100.00 else: raise ValueError('Customer not found') @dataclass class SupportDependencies: customer_id: int db: DatabaseConn class SupportOutput(BaseModel): support_advice: str = Field(description='Advice returned to the customer') block_card: bool = Field(description='Whether to block their card or not') risk: int = Field(description='Risk level of query', ge=0, le=10) support_agent = Agent( 'openai:gpt-4o', deps_type=SupportDependencies, output_type=SupportOutput, system_prompt=( 'You are a support agent in our bank, give the ' 'customer support and judge the risk level of their query. 
' "Reply using the customer's name." ), ) @support_agent.system_prompt async def add_customer_name(ctx: RunContext[SupportDependencies]) -> str: customer_name = await ctx.deps.db.customer_name(id=ctx.deps.customer_id) return f"The customer's name is {customer_name!r}" @support_agent.tool async def customer_balance( ctx: RunContext[SupportDependencies], include_pending: bool ) -> str: """Returns the customer's current account balance.""" balance = await ctx.deps.db.customer_balance( id=ctx.deps.customer_id, include_pending=include_pending, ) return f'${balance:.2f}' if __name__ == '__main__': deps = SupportDependencies(customer_id=123, db=DatabaseConn()) result = support_agent.run_sync('What is my balance?', deps=deps) print(result.output) """ support_advice='Hello John, your current account balance, including pending transactions, is $123.45.' block_card=False risk=1 """ result = support_agent.run_sync('I just lost my card!', deps=deps) print(result.output) """ support_advice="I'm sorry to hear that, John. We are temporarily blocking your card to prevent unauthorized transactions." block_card=True risk=8 """ ``` # Chat App with FastAPI Simple chat app example build with FastAPI. Demonstrates: - [reusing chat history](../../message-history/) - [serializing messages](../../message-history/#accessing-messages-from-results) - [streaming responses](../../output/#streamed-results) This demonstrates storing chat history between requests and using it to give the model context for new responses. Most of the complex logic here is between `chat_app.py` which streams the response to the browser, and `chat_app.ts` which renders messages in the browser. ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.chat_app ``` ```bash uv run -m pydantic_ai_examples.chat_app ``` Then open the app at [localhost:8000](http://localhost:8000). ## Example Code Python code that runs the chat app: [chat_app.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/chat_app.py) ```py """Simple chat app example build with FastAPI. 
Run with: uv run -m pydantic_ai_examples.chat_app """ from __future__ import annotations as _annotations import asyncio import json import sqlite3 from collections.abc import AsyncIterator from concurrent.futures.thread import ThreadPoolExecutor from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timezone from functools import partial from pathlib import Path from typing import Annotated, Any, Callable, Literal, TypeVar import fastapi import logfire from fastapi import Depends, Request from fastapi.responses import FileResponse, Response, StreamingResponse from typing_extensions import LiteralString, ParamSpec, TypedDict from pydantic_ai import Agent from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( ModelMessage, ModelMessagesTypeAdapter, ModelRequest, ModelResponse, TextPart, UserPromptPart, ) # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() agent = Agent('openai:gpt-4o') THIS_DIR = Path(__file__).parent @asynccontextmanager async def lifespan(_app: fastapi.FastAPI): async with Database.connect() as db: yield {'db': db} app = fastapi.FastAPI(lifespan=lifespan) logfire.instrument_fastapi(app) @app.get('/') async def index() -> FileResponse: return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html') @app.get('/chat_app.ts') async def main_ts() -> FileResponse: """Get the raw typescript code, it's compiled in the browser, forgive me.""" return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain') async def get_db(request: Request) -> Database: return request.state.db @app.get('/chat/') async def get_chat(database: Database = Depends(get_db)) -> Response: msgs = await database.get_messages() return Response( b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs), media_type='text/plain', ) class ChatMessage(TypedDict): """Format of messages sent to the browser.""" role: Literal['user', 'model'] timestamp: str content: str def to_chat_message(m: ModelMessage) -> ChatMessage: first_part = m.parts[0] if isinstance(m, ModelRequest): if isinstance(first_part, UserPromptPart): assert isinstance(first_part.content, str) return { 'role': 'user', 'timestamp': first_part.timestamp.isoformat(), 'content': first_part.content, } elif isinstance(m, ModelResponse): if isinstance(first_part, TextPart): return { 'role': 'model', 'timestamp': m.timestamp.isoformat(), 'content': first_part.content, } raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}') @app.post('/chat/') async def post_chat( prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db) ) -> StreamingResponse: async def stream_messages(): """Streams new line delimited JSON `Message`s to the client.""" # stream the user prompt so that can be displayed straight away yield ( json.dumps( { 'role': 'user', 'timestamp': datetime.now(tz=timezone.utc).isoformat(), 'content': prompt, } ).encode('utf-8') + b'\n' ) # get the chat history so far to pass as context to the agent messages = await database.get_messages() # run the agent with the user prompt and the chat history async with agent.run_stream(prompt, message_history=messages) as result: async for text in result.stream(debounce_by=0.01): # text here is a `str` and the frontend wants # JSON encoded ModelResponse, so we create one m = ModelResponse(parts=[TextPart(text)], 
timestamp=result.timestamp()) yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n' # add new messages (e.g. the user prompt and the agent response in this case) to the database await database.add_messages(result.new_messages_json()) return StreamingResponse(stream_messages(), media_type='text/plain') P = ParamSpec('P') R = TypeVar('R') @dataclass class Database: """Rudimentary database to store chat messages in SQLite. The SQLite standard library package is synchronous, so we use a thread pool executor to run queries asynchronously. """ con: sqlite3.Connection _loop: asyncio.AbstractEventLoop _executor: ThreadPoolExecutor @classmethod @asynccontextmanager async def connect( cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite' ) -> AsyncIterator[Database]: with logfire.span('connect to DB'): loop = asyncio.get_event_loop() executor = ThreadPoolExecutor(max_workers=1) con = await loop.run_in_executor(executor, cls._connect, file) slf = cls(con, loop, executor) try: yield slf finally: await slf._asyncify(con.close) @staticmethod def _connect(file: Path) -> sqlite3.Connection: con = sqlite3.connect(str(file)) con = logfire.instrument_sqlite3(con) cur = con.cursor() cur.execute( 'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);' ) con.commit() return con async def add_messages(self, messages: bytes): await self._asyncify( self._execute, 'INSERT INTO messages (message_list) VALUES (?);', messages, commit=True, ) await self._asyncify(self.con.commit) async def get_messages(self) -> list[ModelMessage]: c = await self._asyncify( self._execute, 'SELECT message_list FROM messages order by id' ) rows = await self._asyncify(c.fetchall) messages: list[ModelMessage] = [] for row in rows: messages.extend(ModelMessagesTypeAdapter.validate_json(row[0])) return messages def _execute( self, sql: LiteralString, *args: Any, commit: bool = False ) -> sqlite3.Cursor: cur = self.con.cursor() cur.execute(sql, args) if commit: self.con.commit() return cur async def _asyncify( self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> R: return await self._loop.run_in_executor( # type: ignore self._executor, partial(func, **kwargs), *args, # type: ignore ) if __name__ == '__main__': import uvicorn uvicorn.run( 'pydantic_ai_examples.chat_app:app', reload=True, reload_dirs=[str(THIS_DIR)] ) ``` (The accompanying `chat_app.html` page that renders the app is not reproduced here.) TypeScript to handle rendering the messages, to keep this simple (and at the risk of offending frontend developers) the typescript code is passed to the browser as plain text and transpiled in the browser. 
[chat_app.ts](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/chat_app.ts) ```ts // BIG FAT WARNING: to avoid the complexity of npm, this typescript is compiled in the browser // there's currently no static type checking import { marked } from 'https://cdnjs.cloudflare.com/ajax/libs/marked/15.0.0/lib/marked.esm.js' const convElement = document.getElementById('conversation') const promptInput = document.getElementById('prompt-input') as HTMLInputElement const spinner = document.getElementById('spinner') // stream the response and render messages as each chunk is received // data is sent as newline-delimited JSON async function onFetchResponse(response: Response): Promise { let text = '' let decoder = new TextDecoder() if (response.ok) { const reader = response.body.getReader() while (true) { const {done, value} = await reader.read() if (done) { break } text += decoder.decode(value) addMessages(text) spinner.classList.remove('active') } addMessages(text) promptInput.disabled = false promptInput.focus() } else { const text = await response.text() console.error(`Unexpected response: ${response.status}`, {response, text}) throw new Error(`Unexpected response: ${response.status}`) } } // The format of messages, this matches pydantic-ai both for brevity and understanding // in production, you might not want to keep this format all the way to the frontend interface Message { role: string content: string timestamp: string } // take raw response text and render messages into the `#conversation` element // Message timestamp is assumed to be a unique identifier of a message, and is used to deduplicate // hence you can send data about the same message multiple times, and it will be updated // instead of creating a new message elements function addMessages(responseText: string) { const lines = responseText.split('\n') const messages: Message[] = lines.filter(line => line.length > 1).map(j => JSON.parse(j)) for (const message of messages) { // we use the timestamp as a crude element id const {timestamp, role, content} = message const id = `msg-${timestamp}` let msgDiv = document.getElementById(id) if (!msgDiv) { msgDiv = document.createElement('div') msgDiv.id = id msgDiv.title = `${role} at ${timestamp}` msgDiv.classList.add('border-top', 'pt-2', role) convElement.appendChild(msgDiv) } msgDiv.innerHTML = marked.parse(content) } window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }) } function onError(error: any) { console.error(error) document.getElementById('error').classList.remove('d-none') document.getElementById('spinner').classList.remove('active') } async function onSubmit(e: SubmitEvent): Promise { e.preventDefault() spinner.classList.add('active') const body = new FormData(e.target as HTMLFormElement) promptInput.value = '' promptInput.disabled = true const response = await fetch('/chat/', {method: 'POST', body}) await onFetchResponse(response) } // call onSubmit when the form is submitted (e.g. user clicks the send button or hits Enter) document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError)) // load messages on page load fetch('/chat/').then(onFetchResponse).catch(onError) ``` Example of a multi-agent flow where one agent delegates work to another, then hands off control to a third agent. 
Demonstrates: - [agent delegation](../../multi-agent-applications/#agent-delegation) - [programmatic agent hand-off](../../multi-agent-applications/#programmatic-agent-hand-off) - [usage limits](../../agents/#usage-limits) In this scenario, a group of agents work together to find the best flight for a user. The control flow for this example can be summarised as follows: ``` graph TD START --> search_agent("search agent") search_agent --> extraction_agent("extraction agent") extraction_agent --> search_agent search_agent --> human_confirm("human confirm") human_confirm --> search_agent search_agent --> FAILED human_confirm --> find_seat_function("find seat function") find_seat_function --> human_seat_choice("human seat choice") human_seat_choice --> find_seat_agent("find seat agent") find_seat_agent --> find_seat_function find_seat_function --> buy_flights("buy flights") buy_flights --> SUCCESS ``` ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.flight_booking ``` ```bash uv run -m pydantic_ai_examples.flight_booking ``` ## Example Code [flight_booking.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/flight_booking.py) ```py """Example of a multi-agent flow where one agent delegates work to another. In this scenario, a group of agents work together to find flights for a user. """ import datetime from dataclasses import dataclass from typing import Literal import logfire from pydantic import BaseModel, Field from rich.prompt import Prompt from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.messages import ModelMessage from pydantic_ai.usage import Usage, UsageLimits # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() class FlightDetails(BaseModel): """Details of the most suitable flight.""" flight_number: str price: int origin: str = Field(description='Three-letter airport code') destination: str = Field(description='Three-letter airport code') date: datetime.date class NoFlightFound(BaseModel): """When no valid flight is found.""" @dataclass class Deps: web_page_text: str req_origin: str req_destination: str req_date: datetime.date # This agent is responsible for controlling the flow of the conversation. search_agent = Agent[Deps, FlightDetails | NoFlightFound]( 'openai:gpt-4o', output_type=FlightDetails | NoFlightFound, # type: ignore retries=4, system_prompt=( 'Your job is to find the cheapest flight for the user on the given date. ' ), ) # This agent is responsible for extracting flight details from web page text. 
extraction_agent = Agent( 'openai:gpt-4o', output_type=list[FlightDetails], system_prompt='Extract all the flight details from the given text.', ) @search_agent.tool async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]: """Get details of all flights.""" # we pass the usage to the search agent so requests within this agent are counted result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage) logfire.info('found {flight_count} flights', flight_count=len(result.output)) return result.output @search_agent.output_validator async def validate_output( ctx: RunContext[Deps], output: FlightDetails | NoFlightFound ) -> FlightDetails | NoFlightFound: """Procedural validation that the flight meets the constraints.""" if isinstance(output, NoFlightFound): return output errors: list[str] = [] if output.origin != ctx.deps.req_origin: errors.append( f'Flight should have origin {ctx.deps.req_origin}, not {output.origin}' ) if output.destination != ctx.deps.req_destination: errors.append( f'Flight should have destination {ctx.deps.req_destination}, not {output.destination}' ) if output.date != ctx.deps.req_date: errors.append(f'Flight should be on {ctx.deps.req_date}, not {output.date}') if errors: raise ModelRetry('\n'.join(errors)) else: return output class SeatPreference(BaseModel): row: int = Field(ge=1, le=30) seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] class Failed(BaseModel): """Unable to extract a seat selection.""" # This agent is responsible for extracting the user's seat selection seat_preference_agent = Agent[None, SeatPreference | Failed]( 'openai:gpt-4o', output_type=SeatPreference | Failed, # type: ignore system_prompt=( "Extract the user's seat preference. " 'Seats A and F are window seats. ' 'Row 1 is the front row and has extra leg room. ' 'Rows 14, and 20 also have extra leg room. ' ), ) # in reality this would be downloaded from a booking site, # potentially using another agent to navigate the site flights_web_page = """ 1. Flight SFO-AK123 - Price: $350 - Origin: San Francisco International Airport (SFO) - Destination: Ted Stevens Anchorage International Airport (ANC) - Date: January 10, 2025 2. Flight SFO-AK456 - Price: $370 - Origin: San Francisco International Airport (SFO) - Destination: Fairbanks International Airport (FAI) - Date: January 10, 2025 3. Flight SFO-AK789 - Price: $400 - Origin: San Francisco International Airport (SFO) - Destination: Juneau International Airport (JNU) - Date: January 20, 2025 4. Flight NYC-LA101 - Price: $250 - Origin: San Francisco International Airport (SFO) - Destination: Ted Stevens Anchorage International Airport (ANC) - Date: January 10, 2025 5. Flight CHI-MIA202 - Price: $200 - Origin: Chicago O'Hare International Airport (ORD) - Destination: Miami International Airport (MIA) - Date: January 12, 2025 6. Flight BOS-SEA303 - Price: $120 - Origin: Boston Logan International Airport (BOS) - Destination: Ted Stevens Anchorage International Airport (ANC) - Date: January 12, 2025 7. Flight DFW-DEN404 - Price: $150 - Origin: Dallas/Fort Worth International Airport (DFW) - Destination: Denver International Airport (DEN) - Date: January 10, 2025 8. 
Flight ATL-HOU505 - Price: $180 - Origin: Hartsfield-Jackson Atlanta International Airport (ATL) - Destination: George Bush Intercontinental Airport (IAH) - Date: January 10, 2025 """ # restrict how many requests this app can make to the LLM usage_limits = UsageLimits(request_limit=15) async def main(): deps = Deps( web_page_text=flights_web_page, req_origin='SFO', req_destination='ANC', req_date=datetime.date(2025, 1, 10), ) message_history: list[ModelMessage] | None = None usage: Usage = Usage() # run the agent until a satisfactory flight is found while True: result = await search_agent.run( f'Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}', deps=deps, usage=usage, message_history=message_history, usage_limits=usage_limits, ) if isinstance(result.output, NoFlightFound): print('No flight found') break else: flight = result.output print(f'Flight found: {flight}') answer = Prompt.ask( 'Do you want to buy this flight, or keep searching? (buy/*search)', choices=['buy', 'search', ''], show_choices=False, ) if answer == 'buy': seat = await find_seat(usage) await buy_tickets(flight, seat) break else: message_history = result.all_messages( output_tool_return_content='Please suggest another flight' ) async def find_seat(usage: Usage) -> SeatPreference: message_history: list[ModelMessage] | None = None while True: answer = Prompt.ask('What seat would you like?') result = await seat_preference_agent.run( answer, message_history=message_history, usage=usage, usage_limits=usage_limits, ) if isinstance(result.output, SeatPreference): return result.output else: print('Could not understand seat preference. Please try again.') message_history = result.all_messages() async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference): print(f'Purchasing flight {flight_details=!r} {seat=!r}...') if __name__ == '__main__': import asyncio asyncio.run(main()) ``` # Pydantic Model Simple example of using PydanticAI to construct a Pydantic model from a text input. Demonstrates: - [structured `output_type`](../../output/#structured-output) ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.pydantic_model ``` ```bash uv run -m pydantic_ai_examples.pydantic_model ``` This examples uses `openai:gpt-4o` by default, but it works well with other models, e.g. you can run it with Gemini using: ```bash PYDANTIC_AI_MODEL=gemini-1.5-pro python -m pydantic_ai_examples.pydantic_model ``` ```bash PYDANTIC_AI_MODEL=gemini-1.5-pro uv run -m pydantic_ai_examples.pydantic_model ``` (or `PYDANTIC_AI_MODEL=gemini-1.5-flash ...`) ## Example Code [pydantic_model.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/pydantic_model.py) ```py """Simple example of using PydanticAI to construct a Pydantic model from a text input. 
Run with: uv run -m pydantic_ai_examples.pydantic_model """ import os import logfire from pydantic import BaseModel from pydantic_ai import Agent # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() class MyModel(BaseModel): city: str country: str model = os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o') print(f'Using model: {model}') agent = Agent(model, output_type=MyModel) if __name__ == '__main__': result = agent.run_sync('The windy city in the US of A.') print(result.output) print(result.usage()) ``` # Question Graph Example of a graph for asking and evaluating questions. Demonstrates: - [`pydantic_graph`](../../graph/) ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.question_graph ``` ```bash uv run -m pydantic_ai_examples.question_graph ``` ## Example Code [question_graph.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/question_graph.py) ```py """Example of a graph for asking and evaluating questions. Run with: uv run -m pydantic_ai_examples.question_graph """ from __future__ import annotations as _annotations from dataclasses import dataclass, field from pathlib import Path import logfire from groq import BaseModel from pydantic_graph import ( BaseNode, End, Graph, GraphRunContext, ) from pydantic_graph.persistence.file import FileStatePersistence from pydantic_ai import Agent, format_as_xml from pydantic_ai.messages import ModelMessage # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() ask_agent = Agent('openai:gpt-4o', output_type=str) @dataclass class QuestionState: question: str | None = None ask_agent_messages: list[ModelMessage] = field(default_factory=list) evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) @dataclass class Ask(BaseNode[QuestionState]): async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, ) ctx.state.ask_agent_messages += result.all_messages() ctx.state.question = result.output return Answer(result.output) @dataclass class Answer(BaseNode[QuestionState]): question: str async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: answer = input(f'{self.question}: ') return Evaluate(answer) class EvaluationOutput(BaseModel, use_attribute_docstrings=True): correct: bool """Whether the answer is correct.""" comment: str """Comment on the answer, reprimand the user if the answer is wrong.""" evaluate_agent = Agent( 'openai:gpt-4o', output_type=EvaluationOutput, system_prompt='Given a question and answer, evaluate if the answer is correct.', ) @dataclass class Evaluate(BaseNode[QuestionState, None, str]): answer: str async def run( self, ctx: GraphRunContext[QuestionState], ) -> End[str] | Reprimand: assert ctx.state.question is not None result = await evaluate_agent.run( format_as_xml({'question': ctx.state.question, 'answer': self.answer}), message_history=ctx.state.evaluate_agent_messages, ) ctx.state.evaluate_agent_messages += result.all_messages() if result.output.correct: return End(result.output.comment) else: return Reprimand(result.output.comment) @dataclass class 
Reprimand(BaseNode[QuestionState]): comment: str async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') ctx.state.question = None return Ask() question_graph = Graph( nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState ) async def run_as_continuous(): state = QuestionState() node = Ask() end = await question_graph.run(node, state=state) print('END:', end.output) async def run_as_cli(answer: str | None): persistence = FileStatePersistence(Path('question_graph.json')) persistence.set_graph_types(question_graph) if snapshot := await persistence.load_next(): state = snapshot.state assert answer is not None, ( 'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli "' ) node = Evaluate(answer) else: state = QuestionState() node = Ask() # debug(state, node) async with question_graph.iter(node, state=state, persistence=persistence) as run: while True: node = await run.next() if isinstance(node, End): print('END:', node.data) history = await persistence.load_all() print('history:', '\n'.join(str(e.node) for e in history), sep='\n') print('Finished!') break elif isinstance(node, Answer): print(node.question) break # otherwise just continue if __name__ == '__main__': import asyncio import sys try: sub_command = sys.argv[1] assert sub_command in ('continuous', 'cli', 'mermaid') except (IndexError, AssertionError): print( 'Usage:\n' ' uv run -m pydantic_ai_examples.question_graph mermaid\n' 'or:\n' ' uv run -m pydantic_ai_examples.question_graph continuous\n' 'or:\n' ' uv run -m pydantic_ai_examples.question_graph cli [answer]', file=sys.stderr, ) sys.exit(1) if sub_command == 'mermaid': print(question_graph.mermaid_code(start_node=Ask)) elif sub_command == 'continuous': asyncio.run(run_as_continuous()) else: a = sys.argv[2] if len(sys.argv) > 2 else None asyncio.run(run_as_cli(a)) ``` The mermaid diagram generated in this example looks like this: ``` --- title: question_graph --- stateDiagram-v2 [*] --> Ask Ask --> Answer: ask the question Answer --> Evaluate: answer the question Evaluate --> Congratulate Evaluate --> Castigate Congratulate --> [*]: success Castigate --> Ask: try again ``` # RAG RAG search example. This demo allows you to ask question of the [logfire](https://pydantic.dev/logfire) documentation. Demonstrates: - [tools](../../tools/) - [agent dependencies](../../dependencies/) - RAG search This is done by creating a database containing each section of the markdown documentation, then registering the search tool with the PydanticAI agent. Logic for extracting sections from markdown files and a JSON file with that data is available in [this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992). [PostgreSQL with pgvector](https://github.com/pgvector/pgvector) is used as the search database, the easiest way to download and run pgvector is using Docker: ```bash mkdir postgres-data docker run --rm \ -e POSTGRES_PASSWORD=postgres \ -p 54320:5432 \ -v `pwd`/postgres-data:/var/lib/postgresql/data \ pgvector/pgvector:pg17 ``` As with the [SQL gen](../sql-gen/) example, we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running. We also mount the PostgreSQL `data` directory locally to persist the data if you need to stop and restart the container. 
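If you want to double-check that the container is reachable before building the database, a quick connection test against the same DSN the example uses might look like this (a sketch; the credentials and port come from the docker command above):

```python
import asyncio

import asyncpg


async def check_postgres():
    # DSN matches the docker run command above (password `postgres`, port 54320)
    conn = await asyncpg.connect('postgresql://postgres:postgres@localhost:54320')
    try:
        print(await conn.fetchval('SELECT version()'))
    finally:
        await conn.close()


if __name__ == '__main__':
    asyncio.run(check_postgres())
```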
With that running and [dependencies installed and environment variables set](../#usage), we can build the search database with (**WARNING**: this requires the `OPENAI_API_KEY` env variable and will calling the OpenAI embedding API around 300 times to generate embeddings for each section of the documentation): ```bash python -m pydantic_ai_examples.rag build ``` ```bash uv run -m pydantic_ai_examples.rag build ``` (Note building the database doesn't use PydanticAI right now, instead it uses the OpenAI SDK directly.) You can then ask the agent a question with: ```bash python -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?" ``` ```bash uv run -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?" ``` ## Example Code [rag.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/rag.py) ```py """RAG example with pydantic-ai — using vector search to augment a chat agent. Run pgvector with: mkdir postgres-data docker run --rm -e POSTGRES_PASSWORD=postgres \ -p 54320:5432 \ -v `pwd`/postgres-data:/var/lib/postgresql/data \ pgvector/pgvector:pg17 Build the search DB with: uv run -m pydantic_ai_examples.rag build Ask the agent a question with: uv run -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?" """ from __future__ import annotations as _annotations import asyncio import re import sys import unicodedata from contextlib import asynccontextmanager from dataclasses import dataclass import asyncpg import httpx import logfire import pydantic_core from openai import AsyncOpenAI from pydantic import TypeAdapter from typing_extensions import AsyncGenerator from pydantic_ai import RunContext from pydantic_ai.agent import Agent # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_asyncpg() logfire.instrument_pydantic_ai() @dataclass class Deps: openai: AsyncOpenAI pool: asyncpg.Pool agent = Agent('openai:gpt-4o', deps_type=Deps) @agent.tool async def retrieve(context: RunContext[Deps], search_query: str) -> str: """Retrieve documentation sections based on a search query. Args: context: The call context. search_query: The search query. """ with logfire.span( 'create embedding for {search_query=}', search_query=search_query ): embedding = await context.deps.openai.embeddings.create( input=search_query, model='text-embedding-3-small', ) assert len(embedding.data) == 1, ( f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}' ) embedding = embedding.data[0].embedding embedding_json = pydantic_core.to_json(embedding).decode() rows = await context.deps.pool.fetch( 'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8', embedding_json, ) return '\n\n'.join( f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n' for row in rows ) async def run_agent(question: str): """Entry point to run the agent and perform RAG based question answering.""" openai = AsyncOpenAI() logfire.instrument_openai(openai) logfire.info('Asking "{question}"', question=question) async with database_connect(False) as pool: deps = Deps(openai=openai, pool=pool) answer = await agent.run(question, deps=deps) print(answer.output) ####################################################### # The rest of this file is dedicated to preparing the # # search database, and some utilities. 
# ####################################################### # JSON document from # https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992 DOCS_JSON = ( 'https://gist.githubusercontent.com/' 'samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/' '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json' ) async def build_search_db(): """Build the search database.""" async with httpx.AsyncClient() as client: response = await client.get(DOCS_JSON) response.raise_for_status() sections = sessions_ta.validate_json(response.content) openai = AsyncOpenAI() logfire.instrument_openai(openai) async with database_connect(True) as pool: with logfire.span('create schema'): async with pool.acquire() as conn: async with conn.transaction(): await conn.execute(DB_SCHEMA) sem = asyncio.Semaphore(10) async with asyncio.TaskGroup() as tg: for section in sections: tg.create_task(insert_doc_section(sem, openai, pool, section)) async def insert_doc_section( sem: asyncio.Semaphore, openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection, ) -> None: async with sem: url = section.url() exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url) if exists: logfire.info('Skipping {url=}', url=url) return with logfire.span('create embedding for {url=}', url=url): embedding = await openai.embeddings.create( input=section.embedding_content(), model='text-embedding-3-small', ) assert len(embedding.data) == 1, ( f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}' ) embedding = embedding.data[0].embedding embedding_json = pydantic_core.to_json(embedding).decode() await pool.execute( 'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)', url, section.title, section.content, embedding_json, ) @dataclass class DocsSection: id: int parent: int | None path: str level: int title: str content: str def url(self) -> str: url_path = re.sub(r'\.md$', '', self.path) return ( f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}' ) def embedding_content(self) -> str: return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content)) sessions_ta = TypeAdapter(list[DocsSection]) # pyright: reportUnknownMemberType=false # pyright: reportUnknownVariableType=false @asynccontextmanager async def database_connect( create_db: bool = False, ) -> AsyncGenerator[asyncpg.Pool, None]: server_dsn, database = ( 'postgresql://postgres:postgres@localhost:54320', 'pydantic_ai_rag', ) if create_db: with logfire.span('check and create DB'): conn = await asyncpg.connect(server_dsn) try: db_exists = await conn.fetchval( 'SELECT 1 FROM pg_database WHERE datname = $1', database ) if not db_exists: await conn.execute(f'CREATE DATABASE {database}') finally: await conn.close() pool = await asyncpg.create_pool(f'{server_dsn}/{database}') try: yield pool finally: await pool.close() DB_SCHEMA = """ CREATE EXTENSION IF NOT EXISTS vector; CREATE TABLE IF NOT EXISTS doc_sections ( id serial PRIMARY KEY, url text NOT NULL UNIQUE, title text NOT NULL, content text NOT NULL, -- text-embedding-3-small returns a vector of 1536 floats embedding vector(1536) NOT NULL ); CREATE INDEX IF NOT EXISTS idx_doc_sections_embedding ON doc_sections USING hnsw (embedding vector_l2_ops); """ def slugify(value: str, separator: str, unicode: bool = False) -> str: """Slugify a string, to make it URL friendly.""" # Taken unchanged from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38 if not unicode: # Replace Extended Latin 
characters with ASCII, i.e. `žlutý` => `zluty` value = unicodedata.normalize('NFKD', value) value = value.encode('ascii', 'ignore').decode('ascii') value = re.sub(r'[^\w\s-]', '', value).strip().lower() return re.sub(rf'[{separator}\s]+', separator, value) if __name__ == '__main__': action = sys.argv[1] if len(sys.argv) > 1 else None if action == 'build': asyncio.run(build_search_db()) elif action == 'search': if len(sys.argv) == 3: q = sys.argv[2] else: q = 'How do I configure logfire to work with FastAPI?' asyncio.run(run_agent(q)) else: print( 'uv run --extra examples -m pydantic_ai_examples.rag build|search', file=sys.stderr, ) sys.exit(1) ``` # Slack Lead Qualifier with Modal In this example, we're going to build an agentic app that: - automatically researches each new member that joins a company's public Slack community to see how good of a fit they are for the company's commercial product, - sends this analysis into a (private) Slack channel, and - sends a daily summary of the top 5 leads from the previous 24 hours into a (different) Slack channel. We'll be deploying the app on [Modal](https://modal.com), as it lets you use Python to define an app with web endpoints, scheduled functions, and background functions, and deploy them with a CLI, without needing to set up or manage any infrastructure. It's a great way to lower the barrier for people in your organization to start building and deploying AI agents to make their jobs easier. We also add [Pydantic Logfire](https://pydantic.dev/logfire) to get observability into the app and agent as they're running in response to webhooks and the schedule ## Screenshots This is what the analysis sent into Slack will look like: This is what the corresponding trace in [Logfire](https://pydantic.dev/logfire) will look like: All of these entries can be clicked on to get more details about what happened at that step, including the full conversation with the LLM and HTTP requests and responses. ## Prerequisites If you just want to see the code without actually going through the effort of setting up the bits necessary to run it, feel free to [jump ahead](#the-code). ### Slack app You need to have a Slack workspace and the necessary permissions to create apps. 2. Create a new Slack app using the instructions at . 1. In step 2, "Requesting scopes", request the following scopes: - [`users.read`](https://docs.slack.dev/reference/scopes/users.read) - [`users.read.email`](https://docs.slack.dev/reference/scopes/users.read.email) - [`users.profile.read`](https://docs.slack.dev/reference/scopes/users.profile.read) 1. In step 3, "Installing and authorizing the app", note down the Access Token as we're going to need to store it as a Secret in Modal. 1. You can skip steps 4 and 5. We're going to need to subscribe to the `team_join` event, but at this point you don't have a webhook URL yet. 1. Create the channels the app will post into, and add the Slack app to them: - `#new-slack-leads` - `#daily-slack-leads-summary` These names are hard-coded in the example. If you want to use different channels, you can clone the repo and change them in `examples/pydantic_examples/slack_lead_qualifier/functions.py`. ### Logfire Write Token 1. If you don't have a Logfire account yet, create one on . 1. Create a new project named, for example, `slack-lead-qualifier`. 1. Generate a new Write Token and note it down, as we're going to need to store it as a Secret in Modal. ### OpenAI API Key 1. If you don't have an OpenAI account yet, create one on . 1. 
Create a new API Key in Settings and note it down, as we're going to need to store it as a Secret in Modal. ### Modal account 1. If you don't have a Modal account yet, create one on . 1. Create 3 Secrets of type "Custom" on : - Name: `slack`, key: `SLACK_API_KEY`, value: the Slack Access Token you generated earlier - Name: `logfire`, key: `LOGFIRE_TOKEN`, value: the Logfire Write Token you generated earlier - Name: `openai`, key: `OPENAI_API_KEY`, value: the OpenAI API Key you generated earlier ## Usage 1. Make sure you have the [dependencies installed](../#usage). 1. Authenticate with Modal: ```bash python/uv-run -m modal setup ``` 1. Run the example as an [ephemeral Modal app](https://modal.com/docs/guide/apps#ephemeral-apps), meaning it will only run until you quit it using Ctrl+C: ```bash python/uv-run -m modal serve -m pydantic_ai_examples.slack_lead_qualifier.modal ``` 1. Note down the URL after `Created web function web_app =>`, this is your webhook endpoint URL. 1. Go back to and follow step 4, "Configuring the app for event listening", to subscribe to the `team_join` event with the webhook endpoint URL you noted down as the Request URL. Now when someone new (possibly you with a throwaway email) joins the Slack workspace, you'll see the webhook event being processed in the terminal where you ran `modal serve` and in the Logfire Live view, and after waiting a few seconds you should see the result appear in the `#new-slack-leads` Slack channel! Faking a Slack signup You can also fake a Slack signup event and try out the agent like this, with any name or email you please: ```bash curl -X POST \ -H "Content-Type: application/json" \ -d '{ "type": "event_callback", "event": { "type": "team_join", "user": { "profile": { "email": "samuel@pydantic.dev", "first_name": "Samuel", "last_name": "Colvin", "display_name": "Samuel Colvin" } } } }' ``` Deploying to production If you'd like to deploy this app into your Modal workspace in a persistent fashion, you can use this command: ```bash python/uv-run -m modal deploy -m pydantic_ai_examples.slack_lead_qualifier.modal ``` You'll likely want to [download the code](https://github.com/pydantic/pydantic-ai/tree/main/examples/pydantic_ai_examples/slack_lead_qualifier) first, put it in a new repo, and then do [continuous deployment](https://modal.com/docs/guide/continuous-deployment#github-actions) using GitHub Actions. Don't forget to update the Slack event request URL to the new persistent URL! You'll also want to modify the [instructions for the agent](#agent) to your own situation. ## The code We're going to start with the basics, and then gradually build up into the full app. ### Models #### `Profile` First, we define a [Pydantic](https://docs.pydantic.dev) model that represents a Slack user profile. These are the fields we get from the [`team_join`](https://docs.slack.dev/reference/events/team_join) event that's sent to the webhook endpoint that we'll define in a bit. [slack_lead_qualifier/models.py (L11-L15)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/models.py#L11-L15) ```py ... class Profile(BaseModel): first_name: str | None = None last_name: str | None = None display_name: str | None = None email: str ... ``` We also define a `Profile.as_prompt()` helper method that uses format_as_xml to turn the profile into a string that can be sent to the model. 
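Before looking at the helper itself (shown in the next listing), here's a rough sketch of the kind of string `format_as_xml` produces for a profile — the exact formatting is up to `format_as_xml` and may differ slightly:

```python
from pydantic import BaseModel

from pydantic_ai import format_as_xml


class Profile(BaseModel):  # simplified copy of the model above, for illustration only
    first_name: str | None = None
    email: str


profile = Profile(first_name='Samuel', email='samuel@pydantic.dev')
print(format_as_xml(profile, root_tag='profile'))
# Output is roughly (exact formatting may differ):
# <profile>
#   <first_name>Samuel</first_name>
#   <email>samuel@pydantic.dev</email>
# </profile>
```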
[slack_lead_qualifier/models.py (L7-L19)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/models.py#L7-L19) ```py ... from pydantic_ai import format_as_xml ... class Profile(BaseModel): ... def as_prompt(self) -> str: return format_as_xml(self, root_tag='profile') ... ``` #### `Analysis` The second model we'll need represents the result of the analysis that the agent will perform. We include docstrings to provide additional context to the model on what these fields should contain. [slack_lead_qualifier/models.py (L23-L31)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/models.py#L23-L31) ```py ... class Analysis(BaseModel): profile: Profile organization_name: str organization_domain: str job_title: str relevance: Annotated[int, Ge(1), Le(5)] """Estimated fit for Pydantic Logfire: 1 = low, 5 = high""" summary: str """One-sentence welcome note summarising who they are and how we might help""" ... ``` We also define a `Analysis.as_slack_blocks()` helper method that turns the analysis into some [Slack blocks](https://api.slack.com/reference/block-kit/blocks) that can be sent to the Slack API to post a new message. [slack_lead_qualifier/models.py (L23-L46)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/models.py#L23-L46) ```py ... class Analysis(BaseModel): ... def as_slack_blocks(self, include_relevance: bool = False) -> list[dict[str, Any]]: profile = self.profile relevance = f'({self.relevance}/5)' if include_relevance else '' return [ { 'type': 'markdown', 'text': f'[{profile.display_name}](mailto:{profile.email}), {self.job_title} at [**{self.organization_name}**](https://{self.organization_domain}) {relevance}', }, { 'type': 'markdown', 'text': self.summary, }, ] ``` ### Agent Now it's time to get into Pydantic AI and define the agent that will do the actual analysis! We specify the model we'll use (`openai:gpt-4o`), provide [instructions](../../agents/#instructions), give the agent access to the [DuckDuckGo search tool](../../common-tools/#duckduckgo-search-tool), and tell it to output either an `Analysis` or `None` using the [Native Output](../../output/#native-output) structured output mode. The real meat of the app is in the instructions that tell the agent how to evaluate each new Slack member. If you plan to use this app yourself, you'll of course want to modify them to your own situation. [slack_lead_qualifier/agent.py (L7-L40)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/agent.py#L7-L40) ```py ... from pydantic_ai import Agent, NativeOutput from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool ... agent = Agent( 'openai:gpt-4o', instructions=dedent( """ When a new person joins our public Slack, please put together a brief snapshot so we can be most useful to them. **What to include** 1. **Who they are:** Any details about their professional role or projects (e.g. LinkedIn, GitHub, company bio). 2. **Where they work:** Name of the organisation and its domain. 3. **How we can help:** On a scale of 1–5, estimate how likely they are to benefit from **Pydantic Logfire** (our paid observability tool) based on factors such as company size, product maturity, or AI usage. 
*1 = probably not relevant, 5 = very strong fit.* **Our products (for context only)** • **Pydantic Validation** – Python data-validation (open source) • **Pydantic AI** – Python agent framework (open source) • **Pydantic Logfire** – Observability for traces, logs & metrics with first-class AI support (commercial) **How to research** • Use the provided DuckDuckGo search tool to research the person and the organization they work for, based on the email domain or what you find on e.g. LinkedIn and GitHub. • If you can't find enough to form a reasonable view, return **None**. """ ), tools=[duckduckgo_search_tool()], output_type=NativeOutput([Analysis, NoneType]), ) ... ``` #### `analyze_profile` We also define a `analyze_profile` helper function that takes a `Profile`, runs the agent, and returns an `Analysis` (or `None`), and instrument it using [Logfire](../../logfire/). [slack_lead_qualifier/agent.py (L44-L47)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/agent.py#L44-L47) ```py ... @logfire.instrument('Analyze profile') async def analyze_profile(profile: Profile) -> Analysis | None: result = await agent.run(profile.as_prompt()) return result.output ``` ### Analysis store The next building block we'll need is a place to store all the analyses that have been done so that we can look them up when we send the daily summary. Fortunately, Modal provides us with a convenient way to store some data that can be read back in a subsequent Modal run (webhook or scheduled): [`modal.Dict`](https://modal.com/docs/reference/modal.Dict). We define some convenience methods to easily add, list, and clear analyses. [slack_lead_qualifier/store.py (L4-L31)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/store.py#L4-L31) ```py ... import modal ... class AnalysisStore: @classmethod @logfire.instrument('Add analysis to store') async def add(cls, analysis: Analysis): await cls._get_store().put.aio(analysis.profile.email, analysis.model_dump()) @classmethod @logfire.instrument('List analyses from store') async def list(cls) -> list[Analysis]: return [ Analysis.model_validate(analysis) async for analysis in cls._get_store().values.aio() ] @classmethod @logfire.instrument('Clear analyses from store') async def clear(cls): await cls._get_store().clear.aio() @classmethod def _get_store(cls) -> modal.Dict: return modal.Dict.from_name('analyses', create_if_missing=True) # type: ignore ``` Note Note that `# type: ignore` on the last line -- unfortunately `modal` does not fully define its types, so we need this to stop our static type checker `pyright`, which we run over all Pydantic AI code including examples, from complaining. ### Send Slack message Next, we'll need a way to actually send a Slack message, so we define a simple function that uses Slack's [`chat.postMessage`](https://api.slack.com/methods/chat.postMessage) API. [slack_lead_qualifier/slack.py (L8-L30)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/slack.py#L8-L30) ```py ... 
API_KEY = os.getenv('SLACK_API_KEY') assert API_KEY, 'SLACK_API_KEY is not set' @logfire.instrument('Send Slack message') async def send_slack_message(channel: str, blocks: list[dict[str, Any]]): client = httpx.AsyncClient() response = await client.post( 'https://slack.com/api/chat.postMessage', json={ 'channel': channel, 'blocks': blocks, }, headers={ 'Authorization': f'Bearer {API_KEY}', }, timeout=5, ) response.raise_for_status() result = response.json() if not result.get('ok', False): error = result.get('error', 'Unknown error') raise Exception(f'Failed to send to Slack: {error}') ``` ### Features Now we can start putting these building blocks together to implement the actual features we want! #### `process_slack_member` This function takes a [`Profile`](#profile), [analyzes](#analyze_profile) it using the agent, adds it to the [`AnalysisStore`](#analysis-store), and [sends](#send-slack-message) the analysis into the `#new-slack-leads` channel. [slack_lead_qualifier/functions.py (L4-L45)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/functions.py#L4-L45) ```py ... from .agent import analyze_profile from .models import Profile from .slack import send_slack_message from .store import AnalysisStore ... NEW_LEAD_CHANNEL = '#new-slack-leads' ... @logfire.instrument('Process Slack member') async def process_slack_member(profile: Profile): analysis = await analyze_profile(profile) logfire.info('Analysis', analysis=analysis) if analysis is None: return await AnalysisStore().add(analysis) await send_slack_message( NEW_LEAD_CHANNEL, [ { 'type': 'header', 'text': { 'type': 'plain_text', 'text': f'New Slack member with score {analysis.relevance}/5', }, }, { 'type': 'divider', }, *analysis.as_slack_blocks(), ], ) ... ``` #### `send_daily_summary` This function list all of the analyses in the [`AnalysisStore`](#analysis-store), takes the top 5 by relevance, [sends](#send-slack-message) them into the `#daily-slack-leads-summary` channel, and clears the `AnalysisStore` so that the next daily run won't process these analyses again. [slack_lead_qualifier/functions.py (L8-L85)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/functions.py#L8-L85) ```py ... from .slack import send_slack_message from .store import AnalysisStore ... DAILY_SUMMARY_CHANNEL = '#daily-slack-leads-summary' ... @logfire.instrument('Send daily summary') async def send_daily_summary(): analyses = await AnalysisStore().list() logfire.info('Analyses', analyses=analyses) if len(analyses) == 0: return sorted_analyses = sorted(analyses, key=lambda x: x.relevance, reverse=True) top_analyses = sorted_analyses[:5] blocks = [ { 'type': 'header', 'text': { 'type': 'plain_text', 'text': f'Top {len(top_analyses)} new Slack members from the last 24 hours', }, }, ] for analysis in top_analyses: blocks.extend( [ { 'type': 'divider', }, *analysis.as_slack_blocks(include_relevance=True), ] ) await send_slack_message( DAILY_SUMMARY_CHANNEL, blocks, ) await AnalysisStore().clear() ``` ### Web app As it stands, neither of these functions are actually being called from anywhere. Let's implement a [FastAPI](https://fastapi.tiangolo.com/) endpoint to handle the `team_join` Slack webhook (also known as the [Slack Events API](https://docs.slack.dev/apis/events-api)) and call the [`process_slack_member`](#process_slack_member) function we just defined. We also instrument FastAPI using Logfire for good measure. 
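One detail worth calling out before the code: when you configure the Request URL, Slack's [Events API](https://docs.slack.dev/apis/events-api) first sends a `url_verification` request containing a `challenge` value that the endpoint has to echo back, which is why the endpoint below handles that payload type separately. You can exercise that branch locally with something like this (the URL is illustrative — point it at wherever the FastAPI app is being served):

```python
import httpx

response = httpx.post(
    'http://localhost:8000/',  # illustrative URL, adjust to your local server
    json={'type': 'url_verification', 'challenge': 'test-challenge'},
)
print(response.json())
#> {'challenge': 'test-challenge'}
```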
[slack_lead_qualifier/app.py (L20-L36)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/app.py#L20-L36) ```py ... app = FastAPI() logfire.instrument_fastapi(app, capture_headers=True) @app.post('/') async def process_webhook(payload: dict[str, Any]) -> dict[str, Any]: if payload['type'] == 'url_verification': return {'challenge': payload['challenge']} elif ( payload['type'] == 'event_callback' and payload['event']['type'] == 'team_join' ): profile = Profile.model_validate(payload['event']['user']['profile']) process_slack_member(profile) return {'status': 'OK'} raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) ``` #### `process_slack_member` with Modal I was a little sneaky there -- we're not actually calling the [`process_slack_member`](#process_slack_member) function we defined in `functions.py` directly, as Slack requires webhooks to respond within 3 seconds, and we need a bit more time than that to talk to the LLM, do some web searches, and send the Slack message. Instead, we're calling the following function defined alongside the app, which uses Modal's [`modal.Function.spawn`](https://modal.com/docs/reference/modal.Function#spawn) feature to run a function in the background. (If you're curious what the Modal side of this function looks like, you can [jump ahead](#backgrounded-process_slack_member).) Because `modal.py` (which we'll see in the next section) imports `app.py`, we import from `modal.py` inside the function definition because doing so at the top level would have resulted in a circular import error. We also pass along the current Logfire context to get [Distributed Tracing](https://logfire.pydantic.dev/docs/how-to-guides/distributed-tracing/), meaning that the background function execution will show up nested under the webhook request trace, so that we have everything related to that request in one place. [slack_lead_qualifier/app.py (L11-L16)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/app.py#L11-L16) ```py ... def process_slack_member(profile: Profile): from .modal import process_slack_member as _process_slack_member _process_slack_member.spawn( profile.model_dump(), logfire_ctx=get_context() ) ... ``` ### Modal app Now let's see how easy Modal makes it to deploy all of this. #### Set up Modal The first thing we do is define the Modal app, by specifying the base image to use (Debian with Python 3.13), all the Python packages it needs, and all of the secrets defined in the Modal interface that need to be made available during runtime. [slack_lead_qualifier/modal.py (L4-L21)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/modal.py#L4-L21) ```py ... import modal image = modal.Image.debian_slim(python_version='3.13').pip_install( 'pydantic', 'pydantic_ai_slim[openai,duckduckgo]', 'logfire[httpx,fastapi]', 'fastapi[standard]', 'httpx', ) app = modal.App( name='slack-lead-qualifier', image=image, secrets=[ modal.Secret.from_name('logfire'), modal.Secret.from_name('openai'), modal.Secret.from_name('slack'), ], ) ... ``` #### Set up Logfire Next, we define a function to set up Logfire instrumentation for Pydantic AI and HTTPX. We cannot do this at the top level of the file, as the requested packages (like `logfire`) will only be available within functions running on Modal (like the ones we'll define next). This file, `modal.py`, runs on your local machine and only has access to the `modal` package. 
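A minimal, self-contained sketch of this deferred-import pattern — the names here are illustrative, not part of the example — looks like this:

```python
import modal

# Runtime dependencies live in the container image, not in the local environment
# that defines the app.
image = modal.Image.debian_slim(python_version='3.13').pip_install('logfire')

app = modal.App(name='deferred-import-sketch', image=image)


@app.function()
def remote_task() -> None:
    # Imported inside the function body: `logfire` exists in the image above,
    # but not necessarily on the machine running `modal serve` / `modal deploy`.
    import logfire

    logfire.configure()
    logfire.info('running inside a Modal container')
```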
[slack_lead_qualifier/modal.py (L25-L30)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/modal.py#L25-L30) ```py ... def setup_logfire(): import logfire logfire.configure(service_name=app.name) logfire.instrument_pydantic_ai() logfire.instrument_httpx(capture_all=True) ... ``` #### Web app To deploy a [web endpoint](https://modal.com/docs/guide/webhooks) on Modal, we simply define a function that returns an ASGI app (like FastAPI) and decorate it with `@app.function()` and `@modal.asgi_app()`. This `web_app` function will be run on Modal, so inside the function we can call the `setup_logfire` function that requires the `logfire` package, and import `app.py` which uses the other requested packages. By default, Modal spins up a container to handle a function call (like a web request) on-demand, meaning there's a little bit of startup time to each request. However, Slack requires webhooks to respond within 3 seconds, so we specify `min_containers=1` to keep the web endpoint running and ready to answer requests at all times. This is a bit annoying and wasteful, but fortunately [Modal's pricing](https://modal.com/pricing) is pretty reasonable, you get $30 free monthly compute, and they offer up to $50k in free credits for startup and academic researchers. [slack_lead_qualifier/modal.py (L34-L41)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/modal.py#L34-L41) ```py ... @app.function(min_containers=1) @modal.asgi_app() # type: ignore def web_app(): setup_logfire() from .app import app as _app return _app ... ``` Note Note that `# type: ignore` on the `@modal.asgi_app()` line -- unfortunately `modal` does not fully define its types, so we need this to stop our static type checker `pyright`, which we run over all Pydantic AI code including examples, from complaining. #### Scheduled `send_daily_summary` To define a [scheduled function](https://modal.com/docs/guide/cron), we can use the `@app.function()` decorator with a `schedule` argument. This Modal function will call our imported [`send_daily_summary`](#send_daily_summary) function every day at 8 am UTC. [slack_lead_qualifier/modal.py (L60-L66)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/modal.py#L60-L66) ```py ... @app.function(schedule=modal.Cron('0 8 * * *')) # Every day at 8am UTC async def send_daily_summary(): setup_logfire() from .functions import send_daily_summary as _send_daily_summary await _send_daily_summary() ``` #### Backgrounded `process_slack_member` Finally, we define a Modal function that wraps our [`process_slack_member`](#process_slack_member) function, so that it can run in the background. As you'll remember from when we [spawned this function from the web app](#process_slack_member-with-modal), we passed along the Logfire context to get [Distributed Tracing](https://logfire.pydantic.dev/docs/how-to-guides/distributed-tracing/), so we need to attach it here. [slack_lead_qualifier/modal.py (L45-L56)](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/slack_lead_qualifier/modal.py#L45-L56) ```py ... 
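# Note added for this documentation (not part of modal.py on GitHub): the web app
# spawned this function with `profile.model_dump()`, i.e. a plain dict, since
# arguments to Modal functions are serialized when they're sent to the container;
# that's why the profile is re-validated below with `Profile.model_validate(...)`
# before being handed to the actual `process_slack_member` logic.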
@app.function() async def process_slack_member(profile_raw: dict[str, Any], logfire_ctx: Any): setup_logfire() from logfire.propagate import attach_context from .functions import process_slack_member as _process_slack_member from .models import Profile with attach_context(logfire_ctx): profile = Profile.model_validate(profile_raw) await _process_slack_member(profile) ... ``` ## Conclusion And that's it! Now, assuming you've met the [prerequisites](#prerequisites), you can run or deploy the app using the commands under [usage](#usage). # SQL Generation Example demonstrating how to use PydanticAI to generate SQL queries based on user input. Demonstrates: - [dynamic system prompt](../../agents/#system-prompts) - [structured `output_type`](../../output/#structured-output) - [output validation](../../output/#output-validator-functions) - [agent dependencies](../../dependencies/) ## Running the Example The resulting SQL is validated by running it as an `EXPLAIN` query on PostgreSQL. To run the example, you first need to run PostgreSQL, e.g. via Docker: ```bash docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres ``` *(we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running)* With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.sql_gen ``` ```bash uv run -m pydantic_ai_examples.sql_gen ``` or to use a custom prompt: ```bash python -m pydantic_ai_examples.sql_gen "find me errors" ``` ```bash uv run -m pydantic_ai_examples.sql_gen "find me errors" ``` This model uses `gemini-1.5-flash` by default since Gemini is good at single shot queries of this kind. ## Example Code [sql_gen.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/sql_gen.py) ```py """Example demonstrating how to use PydanticAI to generate SQL queries based on user input. Run postgres with: mkdir postgres-data docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres Run with: uv run -m pydantic_ai_examples.sql_gen "show me logs from yesterday, with level 'error'" """ import asyncio import sys from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import date from typing import Annotated, Any, Union import asyncpg import logfire from annotated_types import MinLen from devtools import debug from pydantic import BaseModel, Field from typing_extensions import TypeAlias from pydantic_ai import Agent, ModelRetry, RunContext, format_as_xml # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_asyncpg() logfire.instrument_pydantic_ai() DB_SCHEMA = """ CREATE TABLE records ( created_at timestamptz, start_timestamp timestamptz, end_timestamp timestamptz, trace_id text, span_id text, parent_span_id text, level log_level, span_name text, message text, attributes_json_schema text, attributes jsonb, tags text[], is_exception boolean, otel_status_message text, service_name text ); """ SQL_EXAMPLES = [ { 'request': 'show me records where foobar is false', 'response': "SELECT * FROM records WHERE attributes->>'foobar' = false", }, { 'request': 'show me records where attributes include the key "foobar"', 'response': "SELECT * FROM records WHERE attributes ? 
'foobar'", }, { 'request': 'show me records from yesterday', 'response': "SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'", }, { 'request': 'show me error records with the tag "foobar"', 'response': "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)", }, ] @dataclass class Deps: conn: asyncpg.Connection class Success(BaseModel): """Response when SQL could be successfully generated.""" sql_query: Annotated[str, MinLen(1)] explanation: str = Field( '', description='Explanation of the SQL query, as markdown' ) class InvalidRequest(BaseModel): """Response the user input didn't include enough information to generate SQL.""" error_message: str Response: TypeAlias = Union[Success, InvalidRequest] agent = Agent[Deps, Response]( 'google-gla:gemini-1.5-flash', # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else output_type=Response, # type: ignore deps_type=Deps, ) @agent.system_prompt async def system_prompt() -> str: return f"""\ Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request. Database schema: {DB_SCHEMA} today's date = {date.today()} {format_as_xml(SQL_EXAMPLES)} """ @agent.output_validator async def validate_output(ctx: RunContext[Deps], output: Response) -> Response: if isinstance(output, InvalidRequest): return output # gemini often adds extraneous backslashes to SQL output.sql_query = output.sql_query.replace('\\', '') if not output.sql_query.upper().startswith('SELECT'): raise ModelRetry('Please create a SELECT query') try: await ctx.deps.conn.execute(f'EXPLAIN {output.sql_query}') except asyncpg.exceptions.PostgresError as e: raise ModelRetry(f'Invalid query: {e}') from e else: return output async def main(): if len(sys.argv) == 1: prompt = 'show me logs from yesterday, with level "error"' else: prompt = sys.argv[1] async with database_connect( 'postgresql://postgres:postgres@localhost:54320', 'pydantic_ai_sql_gen' ) as conn: deps = Deps(conn) result = await agent.run(prompt, deps=deps) debug(result.output) # pyright: reportUnknownMemberType=false # pyright: reportUnknownVariableType=false @asynccontextmanager async def database_connect(server_dsn: str, database: str) -> AsyncGenerator[Any, None]: with logfire.span('check and create DB'): conn = await asyncpg.connect(server_dsn) try: db_exists = await conn.fetchval( 'SELECT 1 FROM pg_database WHERE datname = $1', database ) if not db_exists: await conn.execute(f'CREATE DATABASE {database}') finally: await conn.close() conn = await asyncpg.connect(f'{server_dsn}/{database}') try: with logfire.span('create schema'): async with conn.transaction(): if not db_exists: await conn.execute( "CREATE TYPE log_level AS ENUM ('debug', 'info', 'warning', 'error', 'critical')" ) await conn.execute(DB_SCHEMA) yield conn finally: await conn.close() if __name__ == '__main__': asyncio.run(main()) ``` This example shows how to stream markdown from an agent, using the [`rich`](https://github.com/Textualize/rich) library to highlight the output in the terminal. It'll run the example with both OpenAI and Google Gemini models if the required environment variables are set. 
Demonstrates: - [streaming text responses](../../output/#streaming-text) ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.stream_markdown ``` ```bash uv run -m pydantic_ai_examples.stream_markdown ``` ## Example Code [stream_markdown.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/stream_markdown.py) ```py """This example shows how to stream markdown from an agent, using the `rich` library to display the markdown. Run with: uv run -m pydantic_ai_examples.stream_markdown """ import asyncio import os import logfire from rich.console import Console, ConsoleOptions, RenderResult from rich.live import Live from rich.markdown import CodeBlock, Markdown from rich.syntax import Syntax from rich.text import Text from pydantic_ai import Agent from pydantic_ai.models import KnownModelName # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() agent = Agent() # models to try, and the appropriate env var models: list[tuple[KnownModelName, str]] = [ ('google-gla:gemini-1.5-flash', 'GEMINI_API_KEY'), ('openai:gpt-4o-mini', 'OPENAI_API_KEY'), ('groq:llama-3.3-70b-versatile', 'GROQ_API_KEY'), ] async def main(): prettier_code_blocks() console = Console() prompt = 'Show me a short example of using Pydantic.' console.log(f'Asking: {prompt}...', style='cyan') for model, env_var in models: if env_var in os.environ: console.log(f'Using model: {model}') with Live('', console=console, vertical_overflow='visible') as live: async with agent.run_stream(prompt, model=model) as result: async for message in result.stream(): live.update(Markdown(message)) console.log(result.usage()) else: console.log(f'{model} requires {env_var} to be set.') def prettier_code_blocks(): """Make rich code blocks prettier and easier to copy. From https://github.com/samuelcolvin/aicli/blob/v0.8.0/samuelcolvin_aicli.py#L22 """ class SimpleCodeBlock(CodeBlock): def __rich_console__( self, console: Console, options: ConsoleOptions ) -> RenderResult: code = str(self.text).rstrip() yield Text(self.lexer_name, style='dim') yield Syntax( code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True, ) yield Text(f'/{self.lexer_name}', style='dim') Markdown.elements['fence'] = SimpleCodeBlock if __name__ == '__main__': asyncio.run(main()) ``` Information about whales — an example of streamed structured response validation. Demonstrates: - [streaming structured output](../../output/#streaming-structured-output) This script streams structured responses from GPT-4 about whales, validates the data and displays it as a dynamic table using [`rich`](https://github.com/Textualize/rich) as the data is received. ## Running the Example With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.stream_whales ``` ```bash uv run -m pydantic_ai_examples.stream_whales ``` Should give an output like this: ## Example Code [stream_whales.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/stream_whales.py) ```py """Information about whales — an example of streamed structured response validation. This script streams structured responses from GPT-4 about whales, validates the data and displays it as a dynamic table using Rich as the data is received. 
Run with: uv run -m pydantic_ai_examples.stream_whales """ from typing import Annotated import logfire from pydantic import Field from rich.console import Console from rich.live import Live from rich.table import Table from typing_extensions import NotRequired, TypedDict from pydantic_ai import Agent # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() class Whale(TypedDict): name: str length: Annotated[ float, Field(description='Average length of an adult whale in meters.') ] weight: NotRequired[ Annotated[ float, Field(description='Average weight of an adult whale in kilograms.', ge=50), ] ] ocean: NotRequired[str] description: NotRequired[Annotated[str, Field(description='Short Description')]] agent = Agent('openai:gpt-4', output_type=list[Whale]) async def main(): console = Console() with Live('\n' * 36, console=console) as live: console.print('Requesting data...', style='cyan') async with agent.run_stream( 'Generate me details of 5 species of Whale.' ) as result: console.print('Response:', style='green') async for whales in result.stream(debounce_by=0.01): table = Table( title='Species of Whale', caption='Streaming Structured responses from GPT-4', width=120, ) table.add_column('ID', justify='right') table.add_column('Name') table.add_column('Avg. Length (m)', justify='right') table.add_column('Avg. Weight (kg)', justify='right') table.add_column('Ocean') table.add_column('Description', justify='right') for wid, whale in enumerate(whales, start=1): table.add_row( str(wid), whale['name'], f'{whale["length"]:0.0f}', f'{w:0.0f}' if (w := whale.get('weight')) else '…', whale.get('ocean') or '…', whale.get('description') or '…', ) live.update(table) if __name__ == '__main__': import asyncio asyncio.run(main()) ``` Example of PydanticAI with multiple tools which the LLM needs to call in turn to answer a question. Demonstrates: - [tools](../../tools/) - [agent dependencies](../../dependencies/) - [streaming text responses](../../output/#streaming-text) - Building a [Gradio](https://www.gradio.app/) UI for the agent In this case the idea is a "weather" agent — the user can ask for the weather in multiple locations, the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use the `get_weather` tool to get the weather for those locations. ## Running the Example To run this example properly, you might want to add two extra API keys **(Note if either key is missing, the code will fall back to dummy data, so they're not required)**: - A weather API key from [tomorrow.io](https://www.tomorrow.io/weather-api/) set via `WEATHER_API_KEY` - A geocoding API key from [geocode.maps.co](https://geocode.maps.co/) set via `GEO_API_KEY` With [dependencies installed and environment variables set](../#usage), run: ```bash python -m pydantic_ai_examples.weather_agent ``` ```bash uv run -m pydantic_ai_examples.weather_agent ``` ## Example Code [weather_agent.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/weather_agent.py) ```py """Example of PydanticAI with multiple tools which the LLM needs to call in turn to answer a question. In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities, the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use the `get_weather` tool to get the weather. 
Run with: uv run -m pydantic_ai_examples.weather_agent """ from __future__ import annotations as _annotations import asyncio from dataclasses import dataclass from typing import Any import logfire from httpx import AsyncClient from pydantic import BaseModel from pydantic_ai import Agent, RunContext # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() @dataclass class Deps: client: AsyncClient weather_agent = Agent( 'openai:gpt-4.1-mini', # 'Be concise, reply with one sentence.' is enough for some models (like openai) to use # the below tools appropriately, but others like anthropic and gemini require a bit more direction. instructions='Be concise, reply with one sentence.', deps_type=Deps, retries=2, ) class LatLng(BaseModel): lat: float lng: float @weather_agent.tool async def get_lat_lng(ctx: RunContext[Deps], location_description: str) -> LatLng: """Get the latitude and longitude of a location. Args: ctx: The context. location_description: A description of a location. """ # NOTE: the response here will be random, and is not related to the location description. r = await ctx.deps.client.get( 'https://demo-endpoints.pydantic.workers.dev/latlng', params={'location': location_description}, ) r.raise_for_status() return LatLng.model_validate_json(r.content) @weather_agent.tool async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]: """Get the weather at a location. Args: ctx: The context. lat: Latitude of the location. lng: Longitude of the location. """ # NOTE: the responses here will be random, and are not related to the lat and lng. temp_response, descr_response = await asyncio.gather( ctx.deps.client.get( 'https://demo-endpoints.pydantic.workers.dev/number', params={'min': 10, 'max': 30}, ), ctx.deps.client.get( 'https://demo-endpoints.pydantic.workers.dev/weather', params={'lat': lat, 'lng': lng}, ), ) temp_response.raise_for_status() descr_response.raise_for_status() return { 'temperature': f'{temp_response.text} °C', 'description': descr_response.text, } async def main(): async with AsyncClient() as client: logfire.instrument_httpx(client, capture_all=True) deps = Deps(client=client) result = await weather_agent.run( 'What is the weather like in London and in Wiltshire?', deps=deps ) print('Response:', result.output) if __name__ == '__main__': asyncio.run(main()) ``` ## Running the UI You can build multi-turn chat applications for your agent with [Gradio](https://www.gradio.app/), a framework for building AI web applications entirely in python. Gradio comes with built-in chat components and agent support so the entire UI will be implemented in a single python file! Here's what the UI looks like for the weather agent: Note, to run the UI, you'll need Python 3.10+. ```bash pip install gradio>=5.9.0 python/uv-run -m pydantic_ai_examples.weather_agent_gradio ``` ## UI Code [weather_agent_gradio.py](https://github.com/pydantic/pydantic-ai/blob/main/examples/pydantic_ai_examples/weather_agent_gradio.py) ```py from __future__ import annotations as _annotations import json from httpx import AsyncClient from pydantic_ai.messages import ToolCallPart, ToolReturnPart from pydantic_ai_examples.weather_agent import Deps, weather_agent try: import gradio as gr except ImportError as e: raise ImportError( 'Please install gradio with `pip install gradio`. You must use python>=3.10.' 
) from e TOOL_TO_DISPLAY_NAME = {'get_lat_lng': 'Geocoding API', 'get_weather': 'Weather API'} client = AsyncClient() deps = Deps(client=client) async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list): chatbot.append({'role': 'user', 'content': prompt}) yield gr.Textbox(interactive=False, value=''), chatbot, gr.skip() async with weather_agent.run_stream( prompt, deps=deps, message_history=past_messages ) as result: for message in result.new_messages(): for call in message.parts: if isinstance(call, ToolCallPart): call_args = call.args_as_json_str() metadata = { 'title': f'🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}', } if call.tool_call_id is not None: metadata['id'] = call.tool_call_id gr_message = { 'role': 'assistant', 'content': 'Parameters: ' + call_args, 'metadata': metadata, } chatbot.append(gr_message) if isinstance(call, ToolReturnPart): for gr_message in chatbot: if ( gr_message.get('metadata', {}).get('id', '') == call.tool_call_id ): gr_message['content'] += ( f'\nOutput: {json.dumps(call.content)}' ) yield gr.skip(), chatbot, gr.skip() chatbot.append({'role': 'assistant', 'content': ''}) async for message in result.stream_text(): chatbot[-1]['content'] = message yield gr.skip(), chatbot, gr.skip() past_messages = result.all_messages() yield gr.Textbox(interactive=True), gr.skip(), past_messages async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData): new_history = chatbot[: retry_data.index] previous_prompt = chatbot[retry_data.index]['content'] past_messages = past_messages[: retry_data.index] async for update in stream_from_agent(previous_prompt, new_history, past_messages): yield update def undo(chatbot, past_messages: list, undo_data: gr.UndoData): new_history = chatbot[: undo_data.index] past_messages = past_messages[: undo_data.index] return chatbot[undo_data.index]['content'], new_history, past_messages def select_data(message: gr.SelectData) -> str: return message.value['text'] with gr.Blocks() as demo: gr.HTML( """

<h1>Weather Assistant</h1>
<p>This assistant answers your weather questions.</p>

""" ) past_messages = gr.State([]) chatbot = gr.Chatbot( label='Packing Assistant', type='messages', avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'), examples=[ {'text': 'What is the weather like in Miami?'}, {'text': 'What is the weather like in London?'}, ], ) with gr.Row(): prompt = gr.Textbox( lines=1, show_label=False, placeholder='What is the weather like in New York City?', ) generation = prompt.submit( stream_from_agent, inputs=[prompt, chatbot, past_messages], outputs=[prompt, chatbot, past_messages], ) chatbot.example_select(select_data, None, [prompt]) chatbot.retry( handle_retry, [chatbot, past_messages], [prompt, chatbot, past_messages] ) chatbot.undo(undo, [chatbot, past_messages], [prompt, chatbot, past_messages]) if __name__ == '__main__': demo.launch() ```