As with the SQL gen example, we run postgres on port 54320 to avoid conflicts with any other postgres instances you may have running.
We also mount the PostgreSQL data directory locally to persist the data if you need to stop and restart the container.
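If you don't already have a suitable container running, a sketch along these lines matches the port, credentials, and mounted data directory used by the example. The `pgvector/pgvector:pg17` image tag and the `postgres-data` directory name are assumptions, not prescribed by the example; any PostgreSQL image with the pgvector extension will do, since the schema below runs `CREATE EXTENSION IF NOT EXISTS vector`.

```bash
mkdir postgres-data
docker run --rm \
  -e POSTGRES_PASSWORD=postgres \
  -p 54320:5432 \
  -v "$(pwd)/postgres-data:/var/lib/postgresql/data" \
  pgvector/pgvector:pg17
```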
With that running, and with dependencies installed and environment variables set, we can build the search database with (WARNING: this requires the `OPENAI_API_KEY` env variable and will call the OpenAI embedding API around 300 times to generate embeddings for each section of the documentation):
```bash
python -m pydantic_ai_examples.rag build
```

or, using uv:

```bash
uv run -m pydantic_ai_examples.rag build
```
(Note: building the database doesn't use PydanticAI right now; instead it uses the OpenAI SDK directly.)
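To illustrate that point, the core of the build step is just a direct call to the OpenAI embeddings endpoint, roughly like this. This is a simplified sketch of what the full listing below does; the `embed` helper name is ours, not part of the example.

```python
from openai import AsyncOpenAI


async def embed(text: str) -> list[float]:
    # direct OpenAI SDK call, no PydanticAI involved;
    # AsyncOpenAI() reads OPENAI_API_KEY from the environment
    client = AsyncOpenAI()
    response = await client.embeddings.create(
        input=text, model='text-embedding-3-small'
    )
    return response.data[0].embedding
```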
You can then ask the agent a question with:
```bash
python -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?"
```

or, using uv:

```bash
uv run -m pydantic_ai_examples.rag search "How do I configure logfire to work with FastAPI?"
```
Example Code
rag.py
```python
from __future__ import annotations as _annotations

import asyncio
import re
import sys
import unicodedata
from contextlib import asynccontextmanager
from dataclasses import dataclass

import asyncpg
import httpx
import logfire
import pydantic_core
from openai import AsyncOpenAI
from pydantic import TypeAdapter
from typing_extensions import AsyncGenerator

from pydantic_ai import RunContext
from pydantic_ai.agent import Agent

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_asyncpg()


@dataclass
class Deps:
    openai: AsyncOpenAI
    pool: asyncpg.Pool


agent = Agent('openai:gpt-4o', deps_type=Deps, instrument=True)


@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    with logfire.span('create embedding for {search_query=}', search_query=search_query):
        embedding = await context.deps.openai.embeddings.create(
            input=search_query,
            model='text-embedding-3-small',
        )

    assert len(embedding.data) == 1, (
        f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
    )
    embedding = embedding.data[0].embedding
    embedding_json = pydantic_core.to_json(embedding).decode()
    rows = await context.deps.pool.fetch(
        'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8',
        embedding_json,
    )
    return '\n\n'.join(
        f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n'
        for row in rows
    )


async def run_agent(question: str):
    """Entry point to run the agent and perform RAG based question answering."""
    openai = AsyncOpenAI()
    logfire.instrument_openai(openai)

    logfire.info('Asking "{question}"', question=question)

    async with database_connect(False) as pool:
        deps = Deps(openai=openai, pool=pool)
        answer = await agent.run(question, deps=deps)
    print(answer.data)


#######################################################
# The rest of this file is dedicated to preparing the #
# search database, and some utilities.                #
#######################################################

# JSON document from
# https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992
DOCS_JSON = (
    'https://gist.githubusercontent.com/'
    'samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/'
    '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
)


async def build_search_db():
    """Build the search database."""
    async with httpx.AsyncClient() as client:
        response = await client.get(DOCS_JSON)
        response.raise_for_status()
        sections = sessions_ta.validate_json(response.content)

    openai = AsyncOpenAI()
    logfire.instrument_openai(openai)

    async with database_connect(True) as pool:
        with logfire.span('create schema'):
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(DB_SCHEMA)

        sem = asyncio.Semaphore(10)
        async with asyncio.TaskGroup() as tg:
            for section in sections:
                tg.create_task(insert_doc_section(sem, openai, pool, section))


async def insert_doc_section(
    sem: asyncio.Semaphore,
    openai: AsyncOpenAI,
    pool: asyncpg.Pool,
    section: DocsSection,
) -> None:
    async with sem:
        url = section.url()
        exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url)
        if exists:
            logfire.info('Skipping {url=}', url=url)
            return

        with logfire.span('create embedding for {url=}', url=url):
            embedding = await openai.embeddings.create(
                input=section.embedding_content(),
                model='text-embedding-3-small',
            )
        assert len(embedding.data) == 1, (
            f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
        )
        embedding = embedding.data[0].embedding
        embedding_json = pydantic_core.to_json(embedding).decode()
        await pool.execute(
            'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
            url,
            section.title,
            section.content,
            embedding_json,
        )


@dataclass
class DocsSection:
    id: int
    parent: int | None
    path: str
    level: int
    title: str
    content: str

    def url(self) -> str:
        url_path = re.sub(r'\.md$', '', self.path)
        return (
            f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}'
        )

    def embedding_content(self) -> str:
        return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content))


sessions_ta = TypeAdapter(list[DocsSection])


# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    server_dsn, database = (
        'postgresql://postgres:postgres@localhost:54320',
        'pydantic_ai_rag',
    )
    if create_db:
        with logfire.span('check and create DB'):
            conn = await asyncpg.connect(server_dsn)
            try:
                db_exists = await conn.fetchval(
                    'SELECT 1 FROM pg_database WHERE datname = $1', database
                )
                if not db_exists:
                    await conn.execute(f'CREATE DATABASE {database}')
            finally:
                await conn.close()

    pool = await asyncpg.create_pool(f'{server_dsn}/{database}')
    try:
        yield pool
    finally:
        await pool.close()


DB_SCHEMA = """
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE IF NOT EXISTS doc_sections (
    id serial PRIMARY KEY,
    url text NOT NULL UNIQUE,
    title text NOT NULL,
    content text NOT NULL,
    -- text-embedding-3-small returns a vector of 1536 floats
    embedding vector(1536) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_doc_sections_embedding ON doc_sections USING hnsw (embedding vector_l2_ops);
"""


def slugify(value: str, separator: str, unicode: bool = False) -> str:
    """Slugify a string, to make it URL friendly."""
    # Taken unchanged from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38
    if not unicode:
        # Replace Extended Latin characters with ASCII, i.e. `žlutý` => `zluty`
        value = unicodedata.normalize('NFKD', value)
        value = value.encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value).strip().lower()
    return re.sub(rf'[{separator}\s]+', separator, value)


if __name__ == '__main__':
    action = sys.argv[1] if len(sys.argv) > 1 else None
    if action == 'build':
        asyncio.run(build_search_db())
    elif action == 'search':
        if len(sys.argv) == 3:
            q = sys.argv[2]
        else:
            q = 'How do I configure logfire to work with FastAPI?'
        asyncio.run(run_agent(q))
    else:
        print(
            'uv run --extra examples -m pydantic_ai_examples.rag build|search',
            file=sys.stderr,
        )
        sys.exit(1)
```
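The `retrieve` tool orders results by pgvector's `<->` operator (L2 distance between embeddings), which the `hnsw (embedding vector_l2_ops)` index in `DB_SCHEMA` is there to accelerate. If you want to sanity-check the database after running the build step, a small script along these lines can help; the `check_db` helper is ours, not part of the example, and it simply reuses the same DSN and table as `rag.py`.

```python
import asyncio

import asyncpg


async def check_db() -> None:
    # connect to the database created by `rag.py build` (same DSN as the example)
    conn = await asyncpg.connect(
        'postgresql://postgres:postgres@localhost:54320/pydantic_ai_rag'
    )
    try:
        count = await conn.fetchval('SELECT count(*) FROM doc_sections')
        print(f'{count} documentation sections stored')
        # a few sample rows, to confirm titles and URLs look sensible
        for row in await conn.fetch('SELECT url, title FROM doc_sections LIMIT 3'):
            print(row['title'], '->', row['url'])
    finally:
        await conn.close()


if __name__ == '__main__':
    asyncio.run(check_db())
```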