Joins and Reducers
Join nodes synchronize and aggregate data from parallel execution paths. They use Reducers to combine multiple inputs into a single output.
Overview
When you use parallel execution (broadcasting or mapping), you often need to collect and combine the results. Join nodes serve this purpose by:
- Waiting for all parallel tasks to complete
- Aggregating their outputs using a ReducerFunction
- Passing the aggregated result to the next node
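Conceptually, a join folds the outputs of the parallel tasks into one value, starting from an initial value. The sketch below is plain Python rather than the pydantic_graph API; `fold`, `append`, and `add` are hypothetical names used only for illustration. It shows why the examples on this page sort list results before printing: parallel tasks can finish in any order, so an order-sensitive reducer such as list append can yield a different ordering from run to run, while a commutative reducer such as a sum cannot.

```python
from functools import reduce
from itertools import permutations


def append(current: list, inputs: int) -> list:
    """Toy order-sensitive reducer: the result depends on arrival order."""
    return [*current, inputs]


def add(current: int, inputs: int) -> int:
    """Toy commutative reducer: arrival order does not matter."""
    return current + inputs


def fold(reducer, arrivals, initial):
    """Fold arrivals into one value, the way a join folds task outputs."""
    return reduce(reducer, arrivals, initial)


# Every possible arrival order of three parallel task outputs.
orders = list(permutations([1, 4, 9]))

sums = {fold(add, order, 0) for order in orders}
lists = {tuple(fold(append, order, [])) for order in orders}

print(sums)
#> {14}
print(len(lists))
#> 6
```

All six arrival orders give the same sum, but each gives a differently ordered list.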
Creating Joins
Create a join using GraphBuilder.join with a reducer function and initial value or factory:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append


@dataclass
class SimpleState:
    pass


g = GraphBuilder(state_type=SimpleState, output_type=list[int])


@g.step
async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]:
    return [1, 2, 3, 4, 5]


@g.step
async def square(ctx: StepContext[SimpleState, None, int]) -> int:
    return ctx.inputs * ctx.inputs


# Create a join to collect all squared values
collect = g.join(reduce_list_append, initial_factory=list[int])

g.add(
    g.edge_from(g.start_node).to(generate_numbers),
    g.edge_from(generate_numbers).map().to(square),
    g.edge_from(square).to(collect),
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()


async def main():
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> [1, 4, 9, 16, 25]
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
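The join above takes `initial_factory=list[int]` rather than a ready-made list. One likely motivation, stated here as an assumption rather than as documented library behavior, is the familiar Python pitfall of sharing a single mutable object across runs. The following plain-Python sketch uses a hypothetical `ToyJoin` class (not part of pydantic_graph) to contrast a shared initial value with a factory that builds a fresh accumulator for each run.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable


class ToyJoin:
    """Hypothetical stand-in for a join; not the pydantic_graph implementation."""

    def __init__(
        self,
        initial: list[int] | None = None,
        initial_factory: Callable[[], list[int]] | None = None,
    ) -> None:
        self.initial = initial
        self.initial_factory = initial_factory

    def run(self, inputs: Iterable[int]) -> list[int]:
        # A factory yields a fresh accumulator per run; a plain value is reused.
        if self.initial_factory is not None:
            current = self.initial_factory()
        else:
            current = self.initial
        assert current is not None, 'provide initial or initial_factory'
        for value in inputs:
            current.append(value)  # in-place mutation of the accumulator
        return current


shared = ToyJoin(initial=[])
print(shared.run([1, 2]))
#> [1, 2]
print(shared.run([3]))  # the first run's values leak into the second
#> [1, 2, 3]

fresh = ToyJoin(initial_factory=list[int])
print(fresh.run([1, 2]))
#> [1, 2]
print(fresh.run([3]))
#> [3]
```

Immutable initial values such as `0` or `None`, which other examples on this page pass as `initial`, are not affected by this pitfall.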
Built-in Reducers
Pydantic Graph provides several common reducer types out of the box:
reduce_list_append
reduce_list_append collects all inputs into a list:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=list[str])

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [10, 20, 30]

    @g.step
    async def to_string(ctx: StepContext[SimpleState, None, int]) -> str:
        return f'value-{ctx.inputs}'

    collect = g.join(reduce_list_append, initial_factory=list[str])

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(to_string),
        g.edge_from(to_string).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> ['value-10', 'value-20', 'value-30']
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
reduce_list_extend
reduce_list_extend extends a list with an iterable of items:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_extend


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=list[int])

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [1, 2, 3]

    @g.step
    async def create_range(ctx: StepContext[SimpleState, None, int]) -> list[int]:
        """Create a range from 0 to the input value."""
        return list(range(ctx.inputs))

    collect = g.join(reduce_list_extend, initial_factory=list[int])

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(create_range),
        g.edge_from(create_range).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> [0, 0, 0, 1, 1, 2]
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
reduce_dict_update
reduce_dict_update merges dictionaries together:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_dict_update


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int])

    @g.step
    async def generate_keys(ctx: StepContext[SimpleState, None, None]) -> list[str]:
        return ['apple', 'banana', 'cherry']

    @g.step
    async def create_entry(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]:
        return {ctx.inputs: len(ctx.inputs)}

    merge = g.join(reduce_dict_update, initial_factory=dict[str, int])

    g.add(
        g.edge_from(g.start_node).to(generate_keys),
        g.edge_from(generate_keys).map().to(create_entry),
        g.edge_from(create_entry).to(merge),
        g.edge_from(merge).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    result = {k: result[k] for k in sorted(result)}  # force deterministic ordering
    print(result)
    #> {'apple': 5, 'banana': 6, 'cherry': 6}
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
reduce_null
reduce_null discards all inputs and returns None. Useful when you only care about side effects:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_null


@dataclass
class CounterState:
    total: int = 0


async def main():
    g = GraphBuilder(state_type=CounterState, output_type=int)

    @g.step
    async def generate(ctx: StepContext[CounterState, None, None]) -> list[int]:
        return [1, 2, 3, 4, 5]

    @g.step
    async def accumulate(ctx: StepContext[CounterState, None, int]) -> int:
        ctx.state.total += ctx.inputs
        return ctx.inputs

    # We don't care about the outputs, only the side effect on state
    ignore = g.join(reduce_null, initial=None)

    @g.step
    async def get_total(ctx: StepContext[CounterState, None, None]) -> int:
        return ctx.state.total

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(accumulate),
        g.edge_from(accumulate).to(ignore),
        g.edge_from(ignore).to(get_total),
        g.edge_from(get_total).to(g.end_node),
    )

    graph = g.build()
    state = CounterState()
    result = await graph.run(state=state)
    print(result)
    #> 15
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
reduce_sum
reduce_sum sums numeric values:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_sum


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=int)

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [10, 20, 30, 40]

    @g.step
    async def identity(ctx: StepContext[SimpleState, None, int]) -> int:
        return ctx.inputs

    sum_join = g.join(reduce_sum, initial=0)

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(identity),
        g.edge_from(identity).to(sum_join),
        g.edge_from(sum_join).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(result)
    #> 100
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
ReduceFirstValue
ReduceFirstValue returns the first value it receives and cancels all other parallel tasks. This is useful for "race" scenarios where you want the first successful result:
import asyncio
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import ReduceFirstValue


@dataclass
class SimpleState:
    tasks_completed: int = 0


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=str)

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [1, 12, 13, 14, 15]

    @g.step
    async def slow_process(ctx: StepContext[SimpleState, None, int]) -> str:
        """Simulate variable processing times."""
        # Simulate different delays
        await asyncio.sleep(ctx.inputs * 0.1)
        ctx.state.tasks_completed += 1
        return f'Result from task {ctx.inputs}'

    # Use ReduceFirstValue to get the first result and cancel the rest
    first_result = g.join(ReduceFirstValue[str](), initial=None, node_id='first_result')

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(slow_process),
        g.edge_from(slow_process).to(first_result),
        g.edge_from(first_result).to(g.end_node),
    )

    graph = g.build()
    state = SimpleState()
    result = await graph.run(state=state)
    print(result)
    #> Result from task 1
    print(f'Tasks completed: {state.tasks_completed}')
    #> Tasks completed: 1
(This example is complete and can be run "as is"; asyncio is already imported, so you only need to call asyncio.run(main()) to run main)
Custom Reducers
Create custom reducers by defining a ReducerFunction:
from pydantic_graph.beta import GraphBuilder, StepContext


def reduce_sum(current: int, inputs: int) -> int:
    """A reducer that sums numbers."""
    return current + inputs


async def main():
    g = GraphBuilder(output_type=int)

    @g.step
    async def generate(ctx: StepContext[None, None, None]) -> list[int]:
        return [5, 10, 15, 20]

    @g.step
    async def identity(ctx: StepContext[None, None, int]) -> int:
        return ctx.inputs

    sum_join = g.join(reduce_sum, initial=0)

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(identity),
        g.edge_from(identity).to(sum_join),
        g.edge_from(sum_join).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run()
    print(result)
    #> 50
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
Reducers with State Access
Reducers can access and modify the graph state:
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import ReducerContext


@dataclass
class MetricsState:
    total_count: int = 0
    total_sum: int = 0


@dataclass
class ReducedMetrics:
    count: int = 0
    sum: int = 0


def reduce_metrics_sum(ctx: ReducerContext[MetricsState, None], current: ReducedMetrics, inputs: int) -> ReducedMetrics:
    ctx.state.total_count += 1
    ctx.state.total_sum += inputs
    return ReducedMetrics(count=current.count + 1, sum=current.sum + inputs)


def reduce_metrics_max(current: ReducedMetrics, inputs: ReducedMetrics) -> ReducedMetrics:
    return ReducedMetrics(count=max(current.count, inputs.count), sum=max(current.sum, inputs.sum))


async def main():
    # The graph ends at metrics_max, so its output is a ReducedMetrics instance
    g = GraphBuilder(state_type=MetricsState, output_type=ReducedMetrics)

    @g.step
    async def generate(ctx: StepContext[object, None, None]) -> list[int]:
        return [1, 3, 5, 7, 9, 10, 20, 30, 40]

    @g.step
    async def process_even(ctx: StepContext[MetricsState, None, int]) -> int:
        return ctx.inputs * 2

    @g.step
    async def process_odd(ctx: StepContext[MetricsState, None, int]) -> int:
        return ctx.inputs * 3

    metrics_even = g.join(reduce_metrics_sum, initial_factory=ReducedMetrics, node_id='metrics_even')
    metrics_odd = g.join(reduce_metrics_sum, initial_factory=ReducedMetrics, node_id='metrics_odd')
    metrics_max = g.join(reduce_metrics_max, initial_factory=ReducedMetrics, node_id='metrics_max')

    g.add(
        g.edge_from(g.start_node).to(generate),
        # Send even and odd numbers to their respective `process` steps
        g.edge_from(generate).map().to(
            g.decision()
            .branch(g.match(int, matches=lambda x: x % 2 == 0).label('even').to(process_even))
            .branch(g.match(int, matches=lambda x: x % 2 == 1).label('odd').to(process_odd))
        ),
        # Reduce metrics for even and odd numbers separately
        g.edge_from(process_even).to(metrics_even),
        g.edge_from(process_odd).to(metrics_odd),
        # Aggregate the max values for each field
        g.edge_from(metrics_even).to(metrics_max),
        g.edge_from(metrics_odd).to(metrics_max),
        # Finish the graph run with the final reduced value
        g.edge_from(metrics_max).to(g.end_node),
    )

    graph = g.build()
    state = MetricsState()
    result = await graph.run(state=state)
    print(f'Result: {result}')
    #> Result: ReducedMetrics(count=5, sum=200)
    print(f'State total_count: {state.total_count}')
    #> State total_count: 9
    print(f'State total_sum: {state.total_sum}')
    #> State total_sum: 275
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
Canceling Sibling Tasks
Reducers with access to ReducerContext can call ctx.cancel_sibling_tasks() to cancel all other parallel tasks in the same fork. This is useful for early termination when you've found what you need:
import asyncio
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import ReducerContext


@dataclass
class SearchState:
    searches_completed: int = 0


def reduce_find_match(ctx: ReducerContext[SearchState, None], current: str | None, inputs: str) -> str | None:
    """Return the first input that contains 'target' and cancel remaining tasks."""
    if current is not None:
        # We already found a match, ignore subsequent inputs
        return current
    if 'target' in inputs:
        # Found a match! Cancel all other parallel tasks
        ctx.cancel_sibling_tasks()
        return inputs
    return None


async def main():
    g = GraphBuilder(state_type=SearchState, output_type=str | None)

    @g.step
    async def generate_searches(ctx: StepContext[SearchState, None, None]) -> list[str]:
        return ['item1', 'item2', 'target_item', 'item4', 'item5']

    @g.step
    async def search(ctx: StepContext[SearchState, None, str]) -> str:
        """Simulate a slow search operation."""
        # make the search artificially slower for 'item4' and 'item5'
        search_duration = 0.1 if ctx.inputs not in {'item4', 'item5'} else 1.0
        await asyncio.sleep(search_duration)
        ctx.state.searches_completed += 1
        return ctx.inputs

    find_match = g.join(reduce_find_match, initial=None)

    g.add(
        g.edge_from(g.start_node).to(generate_searches),
        g.edge_from(generate_searches).map().to(search),
        g.edge_from(search).to(find_match),
        g.edge_from(find_match).to(g.end_node),
    )

    graph = g.build()
    state = SearchState()
    result = await graph.run(state=state)
    print(f'Found: {result}')
    #> Found: target_item
    print(f'Searches completed: {state.searches_completed}')
    #> Searches completed: 3
(This example is complete and can be run "as is"; asyncio is already imported, so you only need to call asyncio.run(main()) to run main)
Note that only 3 searches completed instead of all 5, because the reducer canceled the remaining tasks after finding a match.
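For intuition, the same early-termination pattern can be written with plain asyncio. This sketch does not use pydantic_graph and makes no claim about how `cancel_sibling_tasks()` is implemented; `search` and `find_first_match` are hypothetical helpers, and the delays are arbitrary values chosen so that the slow task is still running when the match arrives.

```python
from __future__ import annotations

import asyncio


async def search(item: str, delay: float, completed: list[str]) -> str:
    """Pretend to search for `delay` seconds, then record completion."""
    await asyncio.sleep(delay)
    completed.append(item)
    return item


async def find_first_match(delays: dict[str, float]) -> tuple[str | None, list[str]]:
    completed: list[str] = []
    tasks = [asyncio.create_task(search(item, delay, completed)) for item, delay in delays.items()]
    found: str | None = None
    # Inspect results in completion order and stop at the first match.
    for next_done in asyncio.as_completed(tasks):
        result = await next_done
        if 'target' in result:
            found = result
            break
    # Cancel whatever is still running, then wait for the cancellations to settle.
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
    return found, completed


found, completed = asyncio.run(find_first_match({'item1': 0.01, 'target_item': 0.05, 'item4': 5.0}))
print(found)
#> target_item
print(completed)
#> ['item1', 'target_item']
```

The slow `item4` search never records completion because it is cancelled as soon as the match is seen.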
Multiple Joins
A graph can have multiple independent joins:
from dataclasses import dataclass, field

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append


@dataclass
class MultiState:
    results: dict[str, list[int]] = field(default_factory=dict)


async def main():
    g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]])

    @g.step
    async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]:
        return [1, 2, 3]

    @g.step
    async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]:
        return [10, 20]

    @g.step
    async def process_a(ctx: StepContext[MultiState, None, int]) -> int:
        return ctx.inputs * 2

    @g.step
    async def process_b(ctx: StepContext[MultiState, None, int]) -> int:
        return ctx.inputs * 3

    join_a = g.join(reduce_list_append, initial_factory=list[int], node_id='join_a')
    join_b = g.join(reduce_list_append, initial_factory=list[int], node_id='join_b')

    @g.step
    async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None:
        ctx.state.results['a'] = ctx.inputs

    @g.step
    async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None:
        ctx.state.results['b'] = ctx.inputs

    @g.step
    async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]:
        return ctx.state.results

    g.add(
        g.edge_from(g.start_node).to(source_a, source_b),
        g.edge_from(source_a).map().to(process_a),
        g.edge_from(source_b).map().to(process_b),
        g.edge_from(process_a).to(join_a),
        g.edge_from(process_b).to(join_b),
        g.edge_from(join_a).to(store_a),
        g.edge_from(join_b).to(store_b),
        g.edge_from(store_a, store_b).to(combine),
        g.edge_from(combine).to(g.end_node),
    )

    graph = g.build()
    state = MultiState()
    result = await graph.run(state=state)
    print(f"Group A: {sorted(result['a'])}")
    #> Group A: [2, 4, 6]
    print(f"Group B: {sorted(result['b'])}")
    #> Group B: [30, 60]
(This example is complete, it can be run "as is" — you'll need to add import asyncio; asyncio.run(main()) to run main)
Customizing Join Nodes
Custom Node IDs
Like steps, joins can have custom IDs:
from pydantic_graph.beta.join import reduce_list_append
from basic_join import g
my_join = g.join(reduce_list_append, initial_factory=list[int], node_id='my_custom_join_id')
How Joins Work
Internally, the graph tracks which "fork" each parallel task belongs to. A join:
- Identifies its parent fork (the fork that created the parallel paths)
- Waits for all tasks from that fork to reach the join
- Calls reduce() for each incoming value
- Calls finalize() once all values are received
- Passes the finalized result to downstream nodes
This ensures proper synchronization even with nested parallel operations.
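The steps above can be approximated with a small plain-asyncio model. This is a simplified sketch, not the library's implementation: `toy_join` and `square` are hypothetical names, and the model handles a single fork only, with no fork tracking or nesting. It reduces each value as it arrives, then finalizes once every task from the fork has delivered its value.

```python
import asyncio
from collections.abc import Awaitable, Callable, Iterable
from typing import TypeVar

T = TypeVar('T')  # type of each incoming value
A = TypeVar('A')  # type of the accumulated value


async def toy_join(
    branches: Iterable[Awaitable[T]],
    reduce: Callable[[A, T], A],
    initial: A,
    finalize: Callable[[A], A] = lambda acc: acc,
) -> A:
    """Toy model of a join: reduce values as they arrive, finalize at the end."""
    tasks = [asyncio.ensure_future(branch) for branch in branches]
    current = initial
    # Values are reduced in completion order, not in the order they were forked.
    for next_done in asyncio.as_completed(tasks):
        current = reduce(current, await next_done)
    # Reached only after every task from the fork has delivered its value.
    return finalize(current)


async def square(n: int) -> int:
    await asyncio.sleep(0.01 * (6 - n))  # larger inputs finish first
    return n * n


async def main() -> list[int]:
    return await toy_join(
        (square(n) for n in [1, 2, 3, 4, 5]),
        reduce=lambda acc, value: [*acc, value],
        initial=[],
        finalize=sorted,
    )


print(asyncio.run(main()))
#> [1, 4, 9, 16, 25]
```

Here `finalize=sorted` removes the dependence on completion order, which is the same reason the examples on this page call `sorted(result)` before printing.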
Next Steps
- Learn about parallel execution with broadcasting and mapping
- Explore conditional branching with decision nodes
- See the API reference for complete reducer documentation