diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 00a5ca21e..e3a8c4cdd 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -65,6 +65,7 @@ OpenAIConversationsSession, Session, SessionABC, + SessionSettings, SQLiteSession, ) from .model_settings import ModelSettings @@ -286,6 +287,7 @@ def enable_verbose_stdout_logging(): "AgentHooks", "Session", "SessionABC", + "SessionSettings", "SQLiteSession", "OpenAIConversationsSession", "RunContextWrapper", diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index fefb73026..a4fb4c8b9 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -13,6 +13,7 @@ from ...items import TResponseInputItem from ...memory import SQLiteSession +from ...memory.session_settings import SessionSettings class AdvancedSQLiteSession(SQLiteSession): @@ -25,6 +26,7 @@ def __init__( db_path: str | Path = ":memory:", create_tables: bool = False, logger: logging.Logger | None = None, + session_settings: SessionSettings | None = None, **kwargs, ): """Initialize the AdvancedSQLiteSession. @@ -36,7 +38,12 @@ def __init__( logger: The logger to use. Defaults to the module logger **kwargs: Additional keyword arguments to pass to the superclass """ # noqa: E501 - super().__init__(session_id, db_path, **kwargs) + super().__init__( + session_id=session_id, + db_path=db_path, + session_settings=session_settings, + **kwargs, + ) if create_tables: self._init_structure_tables() self._current_branch_id = "main" @@ -132,12 +139,15 @@ async def get_items( """Get items from current or specified branch. Args: - limit: Maximum number of items to return. If None, returns all items. + limit: Maximum number of items to return. If None, uses session_settings.limit. branch_id: Branch to get items from. If None, uses current branch. Returns: List of conversation items from the specified branch. """ + # Use session settings limit if no explicit limit provided + session_limit = limit if limit is not None else self.session_settings.limit + if branch_id is None: branch_id = self._current_branch_id @@ -148,7 +158,7 @@ def _get_all_items_sync(): # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 with self._lock if self._is_memory_db else threading.Lock(): with closing(conn.cursor()) as cursor: - if limit is None: + if session_limit is None: cursor.execute( """ SELECT m.message_data @@ -169,11 +179,11 @@ def _get_all_items_sync(): ORDER BY s.sequence_number DESC LIMIT ? """, - (self.session_id, branch_id, limit), + (self.session_id, branch_id, session_limit), ) rows = cursor.fetchall() - if limit is not None: + if session_limit is not None: rows = list(reversed(rows)) items = [] @@ -194,7 +204,7 @@ def _get_items_sync(): with self._lock if self._is_memory_db else threading.Lock(): with closing(conn.cursor()) as cursor: # Get message IDs in correct order for this branch - if limit is None: + if session_limit is None: cursor.execute( """ SELECT m.message_data @@ -215,11 +225,11 @@ def _get_items_sync(): ORDER BY s.sequence_number DESC LIMIT ? """, - (self.session_id, branch_id, limit), + (self.session_id, branch_id, session_limit), ) rows = cursor.fetchall() - if limit is not None: + if session_limit is not None: rows = list(reversed(rows)) items = [] diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index 1242b5353..a520c1997 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -40,6 +40,7 @@ from ...items import TResponseInputItem from ...logger import logger from ...memory.session import SessionABC +from ...memory.session_settings import SessionSettings # Type alias for consistency levels ConsistencyLevel = Literal["eventual", "strong"] @@ -64,6 +65,7 @@ def __init__( dapr_client: DaprClient, ttl: int | None = None, consistency: ConsistencyLevel = DAPR_CONSISTENCY_EVENTUAL, + session_settings: SessionSettings | None = None, ): """Initializes a new DaprSession. @@ -77,8 +79,11 @@ def __init__( consistency (ConsistencyLevel, optional): Consistency level for state operations. Use DAPR_CONSISTENCY_EVENTUAL or DAPR_CONSISTENCY_STRONG constants. Defaults to DAPR_CONSISTENCY_EVENTUAL. + session_settings (SessionSettings | None): Session configuration settings including + default limit for retrieving items. If None, uses default SessionSettings(). """ self.session_id = session_id + self.session_settings = session_settings or SessionSettings() self._dapr_client = dapr_client self._state_store_name = state_store_name self._ttl = ttl @@ -97,6 +102,7 @@ def from_address( *, state_store_name: str, dapr_address: str = "localhost:50001", + session_settings: SessionSettings | None = None, **kwargs: Any, ) -> DaprSession: """Create a session from a Dapr sidecar address. @@ -105,6 +111,8 @@ def from_address( session_id (str): Conversation ID. state_store_name (str): Name of the Dapr state store component. dapr_address (str): Dapr sidecar gRPC address. Defaults to "localhost:50001". + session_settings (SessionSettings | None): Session configuration settings including + default limit for retrieving items. If None, uses default SessionSettings(). **kwargs: Additional keyword arguments forwarded to the main constructor (e.g., ttl, consistency). @@ -119,7 +127,11 @@ def from_address( """ dapr_client = DaprClient(address=dapr_address) session = cls( - session_id, state_store_name=state_store_name, dapr_client=dapr_client, **kwargs + session_id, + state_store_name=state_store_name, + dapr_client=dapr_client, + session_settings=session_settings, + **kwargs, ) session._owns_client = True # We created the client, so we own it return session @@ -222,12 +234,15 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: - limit: Maximum number of items to retrieve. If None, retrieves all items. + limit: Maximum number of items to retrieve. If None, uses session_settings.limit. When specified, returns the latest N items in chronological order. Returns: List of input items representing the conversation history """ + # Use session settings limit if no explicit limit provided + session_limit = limit if limit is not None else self.session_settings.limit + async with self._lock: # Get messages from state store with consistency level response = await self._dapr_client.get_state( @@ -239,10 +254,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: messages = self._decode_messages(response.data) if not messages: return [] - if limit is not None: - if limit <= 0: + if session_limit is not None: + if session_limit <= 0: return [] - messages = messages[-limit:] + messages = messages[-session_limit:] items: list[TResponseInputItem] = [] for msg in messages: try: diff --git a/src/agents/extensions/memory/encrypt_session.py b/src/agents/extensions/memory/encrypt_session.py index 1fc032e47..b485329e2 100644 --- a/src/agents/extensions/memory/encrypt_session.py +++ b/src/agents/extensions/memory/encrypt_session.py @@ -38,6 +38,7 @@ from ...items import TResponseInputItem from ...memory.session import SessionABC +from ...memory.session_settings import SessionSettings class EncryptedEnvelope(TypedDict): @@ -135,6 +136,16 @@ def __init__( def __getattr__(self, name): return getattr(self.underlying_session, name) + @property + def session_settings(self) -> SessionSettings: + """Get session settings from the underlying session.""" + return self.underlying_session.session_settings + + @session_settings.setter + def session_settings(self, value: SessionSettings) -> None: + """Set session settings on the underlying session.""" + self.underlying_session.session_settings = value + def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope: if isinstance(item, dict): payload = item diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index bb157f7b9..96cab1011 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -36,6 +36,7 @@ from ...items import TResponseInputItem from ...memory.session import SessionABC +from ...memory.session_settings import SessionSettings class RedisSession(SessionABC): @@ -48,6 +49,7 @@ def __init__( redis_client: Redis, key_prefix: str = "agents:session", ttl: int | None = None, + session_settings: SessionSettings | None = None, ): """Initializes a new RedisSession. @@ -58,8 +60,11 @@ def __init__( Defaults to "agents:session". ttl (int | None, optional): Time-to-live in seconds for session data. If None, data persists indefinitely. Defaults to None. + session_settings (SessionSettings | None): Session configuration settings including + default limit for retrieving items. If None, uses default SessionSettings(). """ self.session_id = session_id + self.session_settings = session_settings or SessionSettings() self._redis = redis_client self._key_prefix = key_prefix self._ttl = ttl @@ -78,6 +83,7 @@ def from_url( *, url: str, redis_kwargs: dict[str, Any] | None = None, + session_settings: SessionSettings | None = None, **kwargs: Any, ) -> RedisSession: """Create a session from a Redis URL string. @@ -87,6 +93,8 @@ def from_url( url (str): Redis URL, e.g. "redis://localhost:6379/0" or "rediss://host:6380". redis_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to redis.asyncio.from_url. + session_settings (SessionSettings | None): Session configuration settings including + default limit for retrieving items. If None, uses default SessionSettings(). **kwargs: Additional keyword arguments forwarded to the main constructor (e.g., key_prefix, ttl, etc.). @@ -96,7 +104,12 @@ def from_url( redis_kwargs = redis_kwargs or {} redis_client = redis.from_url(url, **redis_kwargs) - session = cls(session_id, redis_client=redis_client, **kwargs) + session = cls( + session_id, + redis_client=redis_client, + session_settings=session_settings, + **kwargs, + ) session._owns_client = True # We created the client, so we own it return session @@ -129,22 +142,25 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: - limit: Maximum number of items to retrieve. If None, retrieves all items. + limit: Maximum number of items to retrieve. If None, uses session_settings.limit. When specified, returns the latest N items in chronological order. Returns: List of input items representing the conversation history """ + # Use session settings limit if no explicit limit provided + session_limit = limit if limit is not None else self.session_settings.limit + async with self._lock: - if limit is None: + if session_limit is None: # Get all messages in chronological order raw_messages = await self._redis.lrange(self._messages_key, 0, -1) # type: ignore[misc] # Redis library returns Union[Awaitable[T], T] in async context else: - if limit <= 0: + if session_limit <= 0: return [] # Get the latest N messages (Redis list is ordered chronologically) # Use negative indices to get from the end - Redis uses -N to -1 for last N items - raw_messages = await self._redis.lrange(self._messages_key, -limit, -1) # type: ignore[misc] # Redis library returns Union[Awaitable[T], T] in async context + raw_messages = await self._redis.lrange(self._messages_key, -session_limit, -1) # type: ignore[misc] # Redis library returns Union[Awaitable[T], T] in async context items: list[TResponseInputItem] = [] for raw_msg in raw_messages: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index d9e52e391..bd7c37e69 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -47,6 +47,7 @@ from ...items import TResponseInputItem from ...memory.session import SessionABC +from ...memory.session_settings import SessionSettings class SQLAlchemySession(SessionABC): @@ -64,6 +65,7 @@ def __init__( create_tables: bool = False, sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + session_settings: SessionSettings | None = None, ): """Initializes a new SQLAlchemySession. @@ -77,8 +79,10 @@ def __init__( development and testing when migrations aren't used. sessions_table (str, optional): Override the default table name for sessions if needed. messages_table (str, optional): Override the default table name for messages if needed. + session_settings (SessionSettings | None, optional): Session configuration settings """ self.session_id = session_id + self.session_settings = session_settings or SessionSettings() self._engine = engine self._lock = asyncio.Lock() @@ -142,6 +146,7 @@ def from_url( *, url: str, engine_kwargs: dict[str, Any] | None = None, + session_settings: SessionSettings | None = None, **kwargs: Any, ) -> SQLAlchemySession: """Create a session from a database URL string. @@ -151,6 +156,8 @@ def from_url( url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db". engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to sqlalchemy.ext.asyncio.create_async_engine. + session_settings (SessionSettings | None): Session configuration settings including + default limit for retrieving items. If None, uses default SessionSettings(). **kwargs: Additional keyword arguments forwarded to the main constructor (e.g., create_tables, custom table names, etc.). @@ -159,7 +166,7 @@ def from_url( """ engine_kwargs = engine_kwargs or {} engine = create_async_engine(url, **engine_kwargs) - return cls(session_id, engine=engine, **kwargs) + return cls(session_id, engine=engine, session_settings=session_settings, **kwargs) async def _serialize_item(self, item: TResponseInputItem) -> str: """Serialize an item to JSON string. Can be overridden by subclasses.""" @@ -183,15 +190,19 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: - limit: Maximum number of items to retrieve. If None, retrieves all items. + limit: Maximum number of items to retrieve. If None, uses session_settings.limit. When specified, returns the latest N items in chronological order. Returns: List of input items representing the conversation history """ await self._ensure_tables() + + # Use session settings limit if no explicit limit provided + session_limit = limit if limit is not None else self.session_settings.limit + async with self._session_factory() as sess: - if limit is None: + if session_limit is None: stmt = ( select(self._messages.c.message_data) .where(self._messages.c.session_id == self.session_id) @@ -210,13 +221,13 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: self._messages.c.created_at.desc(), self._messages.c.id.desc(), ) - .limit(limit) + .limit(session_limit) ) result = await sess.execute(stmt) rows: list[str] = [row[0] for row in result.all()] - if limit is not None: + if session_limit is not None: rows.reverse() items: list[TResponseInputItem] = [] diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index 1db1598ac..6b2fdbb10 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,5 +1,6 @@ from .openai_conversations_session import OpenAIConversationsSession from .session import Session, SessionABC +from .session_settings import SessionSettings from .sqlite_session import SQLiteSession from .util import SessionInputCallback @@ -7,6 +8,7 @@ "Session", "SessionABC", "SessionInputCallback", + "SessionSettings", "SQLiteSession", "OpenAIConversationsSession", ] diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 6a14e81a0..5803de347 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -6,6 +6,7 @@ from ..items import TResponseInputItem from .session import SessionABC +from .session_settings import SessionSettings async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = None) -> str: @@ -25,14 +26,40 @@ def __init__( *, conversation_id: str | None = None, openai_client: AsyncOpenAI | None = None, + session_settings: SessionSettings | None = None, ): self._session_id: str | None = conversation_id + self.session_settings = session_settings or SessionSettings() _openai_client = openai_client if _openai_client is None: _openai_client = get_default_openai_client() or AsyncOpenAI() # this never be None here self._openai_client: AsyncOpenAI = _openai_client + @property + def session_id(self) -> str: + """Get the session ID (conversation ID). + + Returns: + The conversation ID for this session. + + Raises: + ValueError: If the session has not been initialized yet. + Call any session method (get_items, add_items, etc.) first + to trigger lazy initialization. + """ + if self._session_id is None: + raise ValueError( + "Session ID not yet available. The session is lazily initialized " + "on first API call. Call get_items(), add_items(), or similar first." + ) + return self._session_id + + @session_id.setter + def session_id(self, value: str) -> None: + """Set the session ID (conversation ID).""" + self._session_id = value + async def _get_session_id(self) -> str: if self._session_id is None: self._session_id = await start_openai_conversations_session(self._openai_client) @@ -43,8 +70,10 @@ async def _clear_session_id(self) -> None: async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: session_id = await self._get_session_id() + # Use session settings limit if no explicit limit provided + session_limit = limit if limit is not None else self.session_settings.limit all_items = [] - if limit is None: + if session_limit is None: async for item in self._openai_client.conversations.items.list( conversation_id=session_id, order="asc", @@ -54,12 +83,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: else: async for item in self._openai_client.conversations.items.list( conversation_id=session_id, - limit=limit, + limit=session_limit, order="desc", ): # calling model_dump() to make this serializable all_items.append(item.model_dump(exclude_unset=True)) - if limit is not None and len(all_items) >= limit: + if session_limit is not None and len(all_items) >= session_limit: break all_items.reverse() diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 9c85af6dd..98b621f6a 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from ..items import TResponseInputItem + from .session_settings import SessionSettings @runtime_checkable @@ -16,6 +17,7 @@ class Session(Protocol): """ session_id: str + session_settings: SessionSettings async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. @@ -61,6 +63,7 @@ class SessionABC(ABC): """ session_id: str + session_settings: SessionSettings @abstractmethod async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: diff --git a/src/agents/memory/session_settings.py b/src/agents/memory/session_settings.py new file mode 100644 index 000000000..6c5c325d6 --- /dev/null +++ b/src/agents/memory/session_settings.py @@ -0,0 +1,50 @@ +"""Session configuration settings.""" + +from __future__ import annotations + +import dataclasses +from dataclasses import fields, replace +from typing import Any + +from pydantic import BaseModel +from pydantic.dataclasses import dataclass + + +@dataclass +class SessionSettings: + """Settings for session operations. + + This class holds optional session configuration parameters that can be used + when interacting with session methods. + """ + + limit: int | None = None + """Maximum number of items to retrieve. If None, retrieves all items.""" + + def resolve(self, override: SessionSettings | None) -> SessionSettings: + """Produce a new SessionSettings by overlaying any non-None values from the + override on top of this instance.""" + if override is None: + return self + + changes = { + field.name: getattr(override, field.name) + for field in fields(self) + if getattr(override, field.name) is not None + } + + return replace(self, **changes) + + def to_json_dict(self) -> dict[str, Any]: + """Convert settings to a JSON-serializable dictionary.""" + dataclass_dict = dataclasses.asdict(self) + + json_dict: dict[str, Any] = {} + + for field_name, value in dataclass_dict.items(): + if isinstance(value, BaseModel): + json_dict[field_name] = value.model_dump(mode="json") + else: + json_dict[field_name] = value + + return json_dict diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 2c2386ec7..f3c366c1e 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -8,6 +8,7 @@ from ..items import TResponseInputItem from .session import SessionABC +from .session_settings import SessionSettings class SQLiteSession(SessionABC): @@ -24,6 +25,7 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + session_settings: SessionSettings | None = None, ): """Initialize the SQLite session. @@ -33,8 +35,11 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + session_settings: Session configuration settings including default limit for + retrieving items. If None, uses default SessionSettings(). """ self.session_id = session_id + self.session_settings = session_settings or SessionSettings() self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table @@ -111,17 +116,19 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: - limit: Maximum number of items to retrieve. If None, retrieves all items. + limit: Maximum number of items to retrieve. If None, uses session_settings.limit. When specified, returns the latest N items in chronological order. Returns: List of input items representing the conversation history """ + # Use session settings limit if no explicit limit provided + session_limit = limit if limit is not None else self.session_settings.limit def _get_items_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): - if limit is None: + if session_limit is None: # Fetch all items in chronological order cursor = conn.execute( f""" @@ -140,13 +147,13 @@ def _get_items_sync(): ORDER BY created_at DESC LIMIT ? """, - (self.session_id, limit), + (self.session_id, session_limit), ) rows = cursor.fetchall() # Reverse to get chronological order when using DESC - if limit is not None: + if session_limit is not None: rows = list(reversed(rows)) items = [] diff --git a/src/agents/run.py b/src/agents/run.py index e772b254e..fbcca171b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -59,7 +59,7 @@ ) from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase from .logger import logger -from .memory import Session, SessionInputCallback +from .memory import Session, SessionInputCallback, SessionSettings from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -197,6 +197,11 @@ class RunConfig: settings. """ + session_settings: SessionSettings | None = None + """Configure session settings. Any non-null values will override the session's default + settings. Used to control session behavior like the number of items to retrieve. + """ + handoff_input_filter: HandoffInputFilter | None = None """A global input filter to apply to all handoffs. If `Handoff.input_filter` is set, then that will take precedence. The input filter allows you to edit the inputs that are sent to the new @@ -565,7 +570,9 @@ async def run( # Keep original user input separate from session-prepared input original_user_input = input prepared_input = await self._prepare_input_with_session( - input, session, run_config.session_input_callback + input, + session, + run_config, ) tool_use_tracker = AgentToolUseTracker() @@ -1094,7 +1101,9 @@ async def _start_streaming( try: # Prepare input with session if enabled prepared_input = await AgentRunner._prepare_input_with_session( - starting_input, session, run_config.session_input_callback + starting_input, + session, + run_config, ) # Update the streamed result with the prepared input @@ -1941,12 +1950,14 @@ async def _prepare_input_with_session( cls, input: str | list[TResponseInputItem], session: Session | None, - session_input_callback: SessionInputCallback | None, + run_config: RunConfig, ) -> str | list[TResponseInputItem]: """Prepare input by combining it with session history if enabled.""" if session is None: return input + session_input_callback: SessionInputCallback | None = run_config.session_input_callback + # If the user doesn't specify an input callback and pass a list as input if isinstance(input, list) and not session_input_callback: raise UserError( @@ -1958,7 +1969,14 @@ async def _prepare_input_with_session( ) # Get previous conversation history - history = await session.get_items() + # Resolve session settings: session defaults + run config overrides + session_settings = session.session_settings + if run_config.session_settings is not None: + session_settings = session_settings.resolve(run_config.session_settings) + + history = await session.get_items( + limit=session_settings.limit, + ) # Convert input to list format new_input_list = ItemHelpers.input_to_new_input_list(input) diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index 40edb99fe..a8c592c38 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -986,3 +986,144 @@ async def test_tool_execution_integration(agent: Agent): assert len(tool_usage) > 0 session.close() + + +# ============================================================================ +# SessionSettings Tests +# ============================================================================ + + +async def test_session_settings_default(): + """Test that session_settings defaults to empty SessionSettings.""" + from agents.memory import SessionSettings + + session = AdvancedSQLiteSession(session_id="default_settings_test", create_tables=True) + + # Should have default SessionSettings (inherited from SQLiteSession) + assert isinstance(session.session_settings, SessionSettings) + assert session.session_settings.limit is None + + session.close() + + +async def test_session_settings_constructor(): + """Test passing session_settings via constructor.""" + from agents.memory import SessionSettings + + session = AdvancedSQLiteSession( + session_id="constructor_settings_test", + create_tables=True, + session_settings=SessionSettings(limit=5), + ) + + assert session.session_settings.limit == 5 + + session.close() + + +async def test_get_items_uses_session_settings_limit(): + """Test that get_items uses session_settings.limit as default.""" + from agents.memory import SessionSettings + + session = AdvancedSQLiteSession( + session_id="uses_settings_limit_test", + create_tables=True, + session_settings=SessionSettings(limit=3), + ) + + # Add 5 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(5) + ] + await session.add_items(items) + + # get_items() with no limit should use session_settings.limit=3 + retrieved = await session.get_items() + assert len(retrieved) == 3 + # Should get the last 3 items + assert retrieved[0].get("content") == "Message 2" + assert retrieved[1].get("content") == "Message 3" + assert retrieved[2].get("content") == "Message 4" + + session.close() + + +async def test_get_items_explicit_limit_overrides_session_settings(): + """Test that explicit limit parameter overrides session_settings.""" + from agents.memory import SessionSettings + + session = AdvancedSQLiteSession( + session_id="explicit_override_test", + create_tables=True, + session_settings=SessionSettings(limit=5), + ) + + # Add 10 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + await session.add_items(items) + + # Explicit limit=2 should override session_settings.limit=5 + retrieved = await session.get_items(limit=2) + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Message 8" + assert retrieved[1].get("content") == "Message 9" + + session.close() + + +async def test_session_settings_resolve(): + """Test SessionSettings.resolve() method.""" + from agents.memory import SessionSettings + + base = SessionSettings(limit=100) + override = SessionSettings(limit=50) + + final = base.resolve(override) + + assert final.limit == 50 # Override wins + assert base.limit == 100 # Original unchanged + + # Resolving with None returns self + final_none = base.resolve(None) + assert final_none.limit == 100 + + +async def test_runner_with_session_settings_override(agent: Agent): + """Test that RunConfig can override session's default settings.""" + from agents import RunConfig + from agents.memory import SessionSettings + + # Session with default limit=100 + session = AdvancedSQLiteSession( + session_id="runner_override_test", + create_tables=True, + session_settings=SessionSettings(limit=100), + ) + + # Add some history + items: list[TResponseInputItem] = [{"role": "user", "content": f"Turn {i}"} for i in range(10)] + await session.add_items(items) + + # Use RunConfig to override limit to 2 + assert isinstance(agent.model, FakeModel) + agent.model.set_next_output([get_text_message("Got it")]) + + await Runner.run( + agent, + "New question", + session=session, + run_config=RunConfig( + session_settings=SessionSettings(limit=2) # Override to 2 + ), + ) + + # Verify the agent received only the last 2 history items + new question + last_input = agent.model.last_turn_args["input"] + # Filter out the new "New question" input + history_items = [item for item in last_input if item.get("content") != "New question"] + # Should have 2 history items (last two from the 10 we added) + assert len(history_items) == 2 + + session.close() diff --git a/tests/extensions/memory/test_dapr_session.py b/tests/extensions/memory/test_dapr_session.py index 26e8743b2..ff23f06ea 100644 --- a/tests/extensions/memory/test_dapr_session.py +++ b/tests/extensions/memory/test_dapr_session.py @@ -830,3 +830,164 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient): assert len(items) == 1 # Close should have been called automatically (though fake client doesn't track this) + + +# ============================================================================ +# SessionSettings Tests +# ============================================================================ + + +async def test_session_settings_default(fake_dapr_client: FakeDaprClient): + """Test that session_settings defaults to empty SessionSettings.""" + from agents.memory import SessionSettings + + session = await _create_test_session(fake_dapr_client) + + try: + # Should have default SessionSettings + assert isinstance(session.session_settings, SessionSettings) + assert session.session_settings.limit is None + finally: + await session.close() + + +async def test_session_settings_constructor(fake_dapr_client: FakeDaprClient): + """Test passing session_settings via constructor.""" + from agents.memory import SessionSettings + + session = DaprSession( + session_id="settings_test", + state_store_name="statestore", + dapr_client=fake_dapr_client, # type: ignore[arg-type] + session_settings=SessionSettings(limit=5), + ) + + try: + assert session.session_settings.limit == 5 + finally: + await session.close() + + +async def test_get_items_uses_session_settings_limit(fake_dapr_client: FakeDaprClient): + """Test that get_items uses session_settings.limit as default.""" + from agents.memory import SessionSettings + + session = DaprSession( + session_id="uses_settings_limit_test", + state_store_name="statestore", + dapr_client=fake_dapr_client, # type: ignore[arg-type] + session_settings=SessionSettings(limit=3), + ) + + try: + await session.clear_session() + + # Add 5 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(5) + ] + await session.add_items(items) + + # get_items() with no limit should use session_settings.limit=3 + retrieved = await session.get_items() + assert len(retrieved) == 3 + # Should get the last 3 items + assert retrieved[0].get("content") == "Message 2" + assert retrieved[1].get("content") == "Message 3" + assert retrieved[2].get("content") == "Message 4" + finally: + await session.close() + + +async def test_get_items_explicit_limit_overrides_session_settings( + fake_dapr_client: FakeDaprClient, +): + """Test that explicit limit parameter overrides session_settings.""" + from agents.memory import SessionSettings + + session = DaprSession( + session_id="explicit_override_test", + state_store_name="statestore", + dapr_client=fake_dapr_client, # type: ignore[arg-type] + session_settings=SessionSettings(limit=5), + ) + + try: + await session.clear_session() + + # Add 10 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + await session.add_items(items) + + # Explicit limit=2 should override session_settings.limit=5 + retrieved = await session.get_items(limit=2) + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Message 8" + assert retrieved[1].get("content") == "Message 9" + finally: + await session.close() + + +async def test_session_settings_resolve(): + """Test SessionSettings.resolve() method.""" + from agents.memory import SessionSettings + + base = SessionSettings(limit=100) + override = SessionSettings(limit=50) + + final = base.resolve(override) + + assert final.limit == 50 # Override wins + assert base.limit == 100 # Original unchanged + + # Resolving with None returns self + final_none = base.resolve(None) + assert final_none.limit == 100 + + +async def test_runner_with_session_settings_override(fake_dapr_client: FakeDaprClient): + """Test that RunConfig can override session's default settings.""" + from agents import Agent, RunConfig, Runner + from agents.memory import SessionSettings + from tests.fake_model import FakeModel + from tests.test_responses import get_text_message + + session = DaprSession( + session_id="runner_override_test", + state_store_name="statestore", + dapr_client=fake_dapr_client, # type: ignore[arg-type] + session_settings=SessionSettings(limit=100), + ) + + try: + await session.clear_session() + + # Add some history + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Turn {i}"} for i in range(10) + ] + await session.add_items(items) + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Got it")]) + + await Runner.run( + agent, + "New question", + session=session, + run_config=RunConfig( + session_settings=SessionSettings(limit=2) # Override to 2 + ), + ) + + # Verify the agent received only the last 2 history items + new question + last_input = model.last_turn_args["input"] + # Filter out the new "New question" input + history_items = [item for item in last_input if item.get("content") != "New question"] + # Should have 2 history items (last two from the 10 we added) + assert len(history_items) == 2 + finally: + await session.close() diff --git a/tests/extensions/memory/test_encrypt_session.py b/tests/extensions/memory/test_encrypt_session.py index 5eb1d9b53..c71bdc952 100644 --- a/tests/extensions/memory/test_encrypt_session.py +++ b/tests/extensions/memory/test_encrypt_session.py @@ -331,3 +331,149 @@ async def test_encrypted_session_delegation(): assert items[0].get("content") == "Test delegation" underlying_session.close() + + +# ============================================================================ +# SessionSettings Tests +# ============================================================================ + + +async def test_session_settings_delegated_to_underlying(encryption_key: str): + """Test that session_settings is correctly delegated to underlying session.""" + from agents.memory import SessionSettings + + temp_dir = tempfile.mkdtemp() + db_path = Path(temp_dir) / "test_settings.db" + underlying = SQLiteSession("test_session", db_path, session_settings=SessionSettings(limit=5)) + + session = EncryptedSession( + session_id="test_session", + underlying_session=underlying, + encryption_key=encryption_key, + ) + + # session_settings should be accessible through EncryptedSession + assert session.session_settings.limit == 5 + + underlying.close() + + +async def test_session_settings_get_items_uses_underlying_limit(encryption_key: str): + """Test that get_items uses underlying session's session_settings.limit.""" + from agents.memory import SessionSettings + + temp_dir = tempfile.mkdtemp() + db_path = Path(temp_dir) / "test_settings_limit.db" + underlying = SQLiteSession("test_session", db_path, session_settings=SessionSettings(limit=3)) + + session = EncryptedSession( + session_id="test_session", + underlying_session=underlying, + encryption_key=encryption_key, + ) + + # Add 5 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(5) + ] + await session.add_items(items) + + # get_items() with no limit should use underlying session_settings.limit=3 + retrieved = await session.get_items() + assert len(retrieved) == 3 + # Should get the last 3 items + assert retrieved[0].get("content") == "Message 2" + assert retrieved[1].get("content") == "Message 3" + assert retrieved[2].get("content") == "Message 4" + + underlying.close() + + +async def test_session_settings_explicit_limit_overrides_settings(encryption_key: str): + """Test that explicit limit parameter overrides session_settings.""" + from agents.memory import SessionSettings + + temp_dir = tempfile.mkdtemp() + db_path = Path(temp_dir) / "test_override.db" + underlying = SQLiteSession("test_session", db_path, session_settings=SessionSettings(limit=5)) + + session = EncryptedSession( + session_id="test_session", + underlying_session=underlying, + encryption_key=encryption_key, + ) + + # Add 10 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + await session.add_items(items) + + # Explicit limit=2 should override session_settings.limit=5 + retrieved = await session.get_items(limit=2) + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Message 8" + assert retrieved[1].get("content") == "Message 9" + + underlying.close() + + +async def test_session_settings_resolve(): + """Test SessionSettings.resolve() method.""" + from agents.memory import SessionSettings + + base = SessionSettings(limit=100) + override = SessionSettings(limit=50) + + final = base.resolve(override) + + assert final.limit == 50 # Override wins + assert base.limit == 100 # Original unchanged + + # Resolving with None returns self + final_none = base.resolve(None) + assert final_none.limit == 100 + + +async def test_runner_with_session_settings_override(encryption_key: str): + """Test that RunConfig can override session's default settings.""" + from agents import Agent, RunConfig, Runner + from agents.memory import SessionSettings + from tests.fake_model import FakeModel + from tests.test_responses import get_text_message + + temp_dir = tempfile.mkdtemp() + db_path = Path(temp_dir) / "test_runner_override.db" + underlying = SQLiteSession("test_session", db_path, session_settings=SessionSettings(limit=100)) + + session = EncryptedSession( + session_id="test_session", + underlying_session=underlying, + encryption_key=encryption_key, + ) + + # Add some history + items: list[TResponseInputItem] = [{"role": "user", "content": f"Turn {i}"} for i in range(10)] + await session.add_items(items) + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Got it")]) + + await Runner.run( + agent, + "New question", + session=session, + run_config=RunConfig( + session_settings=SessionSettings(limit=2) # Override to 2 + ), + ) + + # Verify the agent received only the last 2 history items + new question + last_input = model.last_turn_args["input"] + # Filter out the new "New question" input + history_items = [item for item in last_input if item.get("content") != "New question"] + # Should have 2 history items (last two from the 10 we added) + assert len(history_items) == 2 + + underlying.close() diff --git a/tests/extensions/memory/test_redis_session.py b/tests/extensions/memory/test_redis_session.py index fa7ea8692..be49c8e00 100644 --- a/tests/extensions/memory/test_redis_session.py +++ b/tests/extensions/memory/test_redis_session.py @@ -77,7 +77,7 @@ async def _create_test_session(session_id: str | None = None) -> RedisSession: # Use in-memory fake Redis for testing session = RedisSession(session_id=session_id, redis_client=fake_redis, key_prefix="test:") else: - session = RedisSession.from_url(session_id, url=REDIS_URL) + session = RedisSession.from_url(session_id, url=REDIS_URL, key_prefix="test:") # Ensure we can connect if not await session.ping(): @@ -794,3 +794,201 @@ async def test_close_method_coverage(): # This should trigger the close path for owned clients await session2.close() + + +# ============================================================================ +# SessionSettings Tests +# ============================================================================ + + +async def test_session_settings_default(): + """Test that session_settings defaults to empty SessionSettings.""" + from agents.memory import SessionSettings + + session = await _create_test_session() + + try: + # Should have default SessionSettings + assert isinstance(session.session_settings, SessionSettings) + assert session.session_settings.limit is None + finally: + await session.close() + + +async def test_session_settings_constructor(): + """Test passing session_settings via constructor.""" + from agents.memory import SessionSettings + + if USE_FAKE_REDIS: + session = RedisSession( + session_id="settings_test", + redis_client=fake_redis, + key_prefix="test:", + session_settings=SessionSettings(limit=5), + ) + else: + session = RedisSession.from_url( + "settings_test", url=REDIS_URL, session_settings=SessionSettings(limit=5) + ) + + try: + assert session.session_settings.limit == 5 + finally: + await session.close() + + +async def test_session_settings_from_url(): + """Test passing session_settings via from_url.""" + if USE_FAKE_REDIS: + pytest.skip("from_url test requires real Redis server") + + from agents.memory import SessionSettings + + session = RedisSession.from_url( + "from_url_settings_test", url=REDIS_URL, session_settings=SessionSettings(limit=10) + ) + + try: + if not await session.ping(): + pytest.skip("Redis server not available") + assert session.session_settings.limit == 10 + finally: + await session.close() + + +async def test_get_items_uses_session_settings_limit(): + """Test that get_items uses session_settings.limit as default.""" + from agents.memory import SessionSettings + + if USE_FAKE_REDIS: + session = RedisSession( + session_id="uses_settings_limit_test", + redis_client=fake_redis, + key_prefix="test:", + session_settings=SessionSettings(limit=3), + ) + else: + session = RedisSession.from_url( + "uses_settings_limit_test", url=REDIS_URL, session_settings=SessionSettings(limit=3) + ) + + try: + await session.clear_session() + + # Add 5 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(5) + ] + await session.add_items(items) + + # get_items() with no limit should use session_settings.limit=3 + retrieved = await session.get_items() + assert len(retrieved) == 3 + # Should get the last 3 items + assert retrieved[0].get("content") == "Message 2" + assert retrieved[1].get("content") == "Message 3" + assert retrieved[2].get("content") == "Message 4" + finally: + await session.close() + + +async def test_get_items_explicit_limit_overrides_session_settings(): + """Test that explicit limit parameter overrides session_settings.""" + from agents.memory import SessionSettings + + if USE_FAKE_REDIS: + session = RedisSession( + session_id="explicit_override_test", + redis_client=fake_redis, + key_prefix="test:", + session_settings=SessionSettings(limit=5), + ) + else: + session = RedisSession.from_url( + "explicit_override_test", url=REDIS_URL, session_settings=SessionSettings(limit=5) + ) + + try: + await session.clear_session() + + # Add 10 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + await session.add_items(items) + + # Explicit limit=2 should override session_settings.limit=5 + retrieved = await session.get_items(limit=2) + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Message 8" + assert retrieved[1].get("content") == "Message 9" + finally: + await session.close() + + +async def test_session_settings_resolve(): + """Test SessionSettings.resolve() method.""" + from agents.memory import SessionSettings + + base = SessionSettings(limit=100) + override = SessionSettings(limit=50) + + final = base.resolve(override) + + assert final.limit == 50 # Override wins + assert base.limit == 100 # Original unchanged + + # Resolving with None returns self + final_none = base.resolve(None) + assert final_none.limit == 100 + + +async def test_runner_with_session_settings_override(): + """Test that RunConfig can override session's default settings.""" + from agents import Agent, RunConfig, Runner + from agents.memory import SessionSettings + from tests.fake_model import FakeModel + from tests.test_responses import get_text_message + + if USE_FAKE_REDIS: + session = RedisSession( + session_id="runner_override_test", + redis_client=fake_redis, + key_prefix="test:", + session_settings=SessionSettings(limit=100), + ) + else: + session = RedisSession.from_url( + "runner_override_test", url=REDIS_URL, session_settings=SessionSettings(limit=100) + ) + + try: + await session.clear_session() + + # Add some history + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Turn {i}"} for i in range(10) + ] + await session.add_items(items) + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Got it")]) + + await Runner.run( + agent, + "New question", + session=session, + run_config=RunConfig( + session_settings=SessionSettings(limit=2) # Override to 2 + ), + ) + + # Verify the agent received only the last 2 history items + new question + last_input = model.last_turn_args["input"] + # Filter out the new "New question" input + history_items = [item for item in last_input if item.get("content") != "New question"] + # Should have 2 history items (last two from the 10 we added) + assert len(history_items) == 2 + finally: + await session.close() diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index b280a000f..9e2220447 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -445,3 +445,133 @@ async def test_engine_property_is_read_only(): # Clean up await session.engine.dispose() + + +async def test_session_settings_default(): + """Test that session_settings defaults to empty SessionSettings.""" + from agents.memory import SessionSettings + + session = SQLAlchemySession.from_url("default_settings_test", url=DB_URL, create_tables=True) + + # Should have default SessionSettings + assert isinstance(session.session_settings, SessionSettings) + assert session.session_settings.limit is None + + +async def test_session_settings_from_url(): + """Test passing session_settings via from_url.""" + from agents.memory import SessionSettings + + session = SQLAlchemySession.from_url( + "from_url_settings_test", + url=DB_URL, + create_tables=True, + session_settings=SessionSettings(limit=5), + ) + + assert session.session_settings.limit == 5 + + +async def test_get_items_uses_session_settings_limit(): + """Test that get_items uses session_settings.limit as default.""" + from agents.memory import SessionSettings + + session = SQLAlchemySession.from_url( + "uses_settings_limit_test", + url=DB_URL, + create_tables=True, + session_settings=SessionSettings(limit=3), + ) + + # Add 5 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(5) + ] + await session.add_items(items) + + # get_items() with no limit should use session_settings.limit=3 + retrieved = await session.get_items() + assert len(retrieved) == 3 + # Should get the last 3 items + assert retrieved[0].get("content") == "Message 2" + assert retrieved[1].get("content") == "Message 3" + assert retrieved[2].get("content") == "Message 4" + + +async def test_get_items_explicit_limit_overrides_session_settings(): + """Test that explicit limit parameter overrides session_settings.""" + from agents.memory import SessionSettings + + session = SQLAlchemySession.from_url( + "explicit_override_test", + url=DB_URL, + create_tables=True, + session_settings=SessionSettings(limit=5), + ) + + # Add 10 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + await session.add_items(items) + + # Explicit limit=2 should override session_settings.limit=5 + retrieved = await session.get_items(limit=2) + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Message 8" + assert retrieved[1].get("content") == "Message 9" + + +async def test_session_settings_resolve(): + """Test SessionSettings.resolve() method.""" + from agents.memory import SessionSettings + + base = SessionSettings(limit=100) + override = SessionSettings(limit=50) + + final = base.resolve(override) + + assert final.limit == 50 # Override wins + assert base.limit == 100 # Original unchanged + + # Resolving with None returns self + final_none = base.resolve(None) + assert final_none.limit == 100 + + +async def test_runner_with_session_settings_override(agent: Agent): + """Test that RunConfig can override session's default settings.""" + from agents import RunConfig + from agents.memory import SessionSettings + + # Session with default limit=100 + session = SQLAlchemySession.from_url( + "runner_override_test", + url=DB_URL, + create_tables=True, + session_settings=SessionSettings(limit=100), + ) + + # Add some history + items: list[TResponseInputItem] = [{"role": "user", "content": f"Turn {i}"} for i in range(10)] + await session.add_items(items) + + # Use RunConfig to override limit to 2 + assert isinstance(agent.model, FakeModel) + agent.model.set_next_output([get_text_message("Got it")]) + + await Runner.run( + agent, + "New question", + session=session, + run_config=RunConfig( + session_settings=SessionSettings(limit=2) # Override to 2 + ), + ) + + # Verify the agent received only the last 2 history items + new question + last_input = agent.model.last_turn_args["input"] + # Filter out the new "New question" input + history_items = [item for item in last_input if item.get("content") != "New question"] + # Should have 2 history items (last two from the 10 we added) + assert len(history_items) == 2 diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index c28ce8fb1..bc3787c94 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -19,6 +19,7 @@ RunHooks, Runner, Session, + SessionSettings, TResponseInputItem, ) from agents.stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent @@ -300,6 +301,7 @@ async def test_agent_as_tool_custom_output_extractor(monkeypatch: pytest.MonkeyP class DummySession(Session): session_id = "sess_123" + session_settings = SessionSettings() async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return [] diff --git a/tests/test_openai_conversations_session.py b/tests/test_openai_conversations_session.py index 732c1fa2c..4acbe3b70 100644 --- a/tests/test_openai_conversations_session.py +++ b/tests/test_openai_conversations_session.py @@ -443,3 +443,32 @@ async def test_session_id_lazy_creation_consistency(self, mock_openai_client): # Conversation should only be created once mock_openai_client.conversations.create.assert_called_once() + + +# ============================================================================ +# SessionSettings Tests +# ============================================================================ + + +class TestOpenAIConversationsSessionSettings: + """Test SessionSettings integration with OpenAIConversationsSession.""" + + def test_session_settings_default(self, mock_openai_client): + """Test that session_settings defaults to empty SessionSettings.""" + from agents.memory import SessionSettings + + session = OpenAIConversationsSession(openai_client=mock_openai_client) + + # Should have default SessionSettings + assert isinstance(session.session_settings, SessionSettings) + assert session.session_settings.limit is None + + def test_session_settings_constructor(self, mock_openai_client): + """Test passing session_settings via constructor.""" + from agents.memory import SessionSettings + + session = OpenAIConversationsSession( + openai_client=mock_openai_client, session_settings=SessionSettings(limit=5) + ) + + assert session.session_settings.limit == 5 diff --git a/tests/test_session.py b/tests/test_session.py index e0328056b..9569df9eb 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -564,3 +564,148 @@ async def consume_stream(): await asyncio.wait_for(consume_stream(), timeout=5.0) session.close() + + +# ============================================================================ +# SessionSettings Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_session_settings_default(): + """Test that session_settings defaults to empty SessionSettings.""" + from agents.memory import SessionSettings + + session = SQLiteSession("default_settings_test") + + # Should have default SessionSettings + assert isinstance(session.session_settings, SessionSettings) + assert session.session_settings.limit is None + + session.close() + + +@pytest.mark.asyncio +async def test_session_settings_constructor(): + """Test passing session_settings via constructor.""" + from agents.memory import SessionSettings + + session = SQLiteSession("constructor_settings_test", session_settings=SessionSettings(limit=5)) + + assert session.session_settings.limit == 5 + + session.close() + + +@pytest.mark.asyncio +async def test_get_items_uses_session_settings_limit(): + """Test that get_items uses session_settings.limit as default.""" + from agents.memory import SessionSettings + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_settings_limit.db" + session = SQLiteSession( + "uses_settings_limit_test", db_path, session_settings=SessionSettings(limit=3) + ) + + # Add 5 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(5) + ] + await session.add_items(items) + + # get_items() with no limit should use session_settings.limit=3 + retrieved = await session.get_items() + assert len(retrieved) == 3 + # Should get the last 3 items + assert retrieved[0].get("content") == "Message 2" + assert retrieved[1].get("content") == "Message 3" + assert retrieved[2].get("content") == "Message 4" + + session.close() + + +@pytest.mark.asyncio +async def test_get_items_explicit_limit_overrides_session_settings(): + """Test that explicit limit parameter overrides session_settings.""" + from agents.memory import SessionSettings + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_override.db" + session = SQLiteSession( + "explicit_override_test", db_path, session_settings=SessionSettings(limit=5) + ) + + # Add 10 items + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Message {i}"} for i in range(10) + ] + await session.add_items(items) + + # Explicit limit=2 should override session_settings.limit=5 + retrieved = await session.get_items(limit=2) + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Message 8" + assert retrieved[1].get("content") == "Message 9" + + session.close() + + +@pytest.mark.asyncio +async def test_session_settings_resolve(): + """Test SessionSettings.resolve() method.""" + from agents.memory import SessionSettings + + base = SessionSettings(limit=100) + override = SessionSettings(limit=50) + + final = base.resolve(override) + + assert final.limit == 50 # Override wins + assert base.limit == 100 # Original unchanged + + # Resolving with None returns self + final_none = base.resolve(None) + assert final_none.limit == 100 + + +@pytest.mark.asyncio +async def test_runner_with_session_settings_override(): + """Test that RunConfig can override session's default settings.""" + from agents.memory import SessionSettings + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_runner_override.db" + + # Session with default limit=100 + session = SQLiteSession( + "runner_override_test", db_path, session_settings=SessionSettings(limit=100) + ) + + # Add some history + items: list[TResponseInputItem] = [ + {"role": "user", "content": f"Turn {i}"} for i in range(10) + ] + await session.add_items(items) + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Got it")]) + + await Runner.run( + agent, + "New question", + session=session, + run_config=RunConfig( + session_settings=SessionSettings(limit=2) # Override to 2 + ), + ) + + # Verify the agent received only the last 2 history items + new question + last_input = model.last_turn_args["input"] + # Filter out the new "New question" input + history_items = [item for item in last_input if item.get("content") != "New question"] + # Should have 2 history items (last two from the 10 we added) + assert len(history_items) == 2 + + session.close() diff --git a/tests/test_session_limit.py b/tests/test_session_limit.py new file mode 100644 index 000000000..f8625f05c --- /dev/null +++ b/tests/test_session_limit.py @@ -0,0 +1,176 @@ +"""Test session_limit parameter functionality via SessionSettings.""" + +import tempfile +from pathlib import Path + +import pytest + +from agents import Agent, RunConfig, SQLiteSession +from agents.memory import SessionSettings +from tests.fake_model import FakeModel +from tests.test_responses import get_text_message +from tests.test_session import run_agent_async + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_limit_parameter(runner_method): + """Test that session_limit parameter correctly limits conversation history + retrieved from session across all Runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_limit.db" + session_id = "limit_test" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Build up a longer conversation history + model.set_next_output([get_text_message("Reply 1")]) + await run_agent_async(runner_method, agent, "Message 1", session=session) + + model.set_next_output([get_text_message("Reply 2")]) + await run_agent_async(runner_method, agent, "Message 2", session=session) + + model.set_next_output([get_text_message("Reply 3")]) + await run_agent_async(runner_method, agent, "Message 3", session=session) + + # Verify we have 6 items in total (3 user + 3 assistant) + all_items = await session.get_items() + assert len(all_items) == 6 + + # Test session_limit via RunConfig - should only get last 2 history items + new input + model.set_next_output([get_text_message("Reply 4")]) + await run_agent_async( + runner_method, + agent, + "Message 4", + session=session, + run_config=RunConfig(session_settings=SessionSettings(limit=2)), + ) + + # Verify model received limited history + last_input = model.last_turn_args["input"] + # Should have: 2 history items + 1 new message = 3 total + assert len(last_input) == 3 + # First item should be "Message 3" (not Message 1 or 2) + assert last_input[0].get("content") == "Message 3" + # Assistant message has content as a list + assert last_input[1].get("content")[0]["text"] == "Reply 3" + assert last_input[2].get("content") == "Message 4" + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_limit_zero(runner_method): + """Test that session_limit=0 provides no history, only new message.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_limit_zero.db" + session_id = "limit_zero_test" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Build conversation history + model.set_next_output([get_text_message("Reply 1")]) + await run_agent_async(runner_method, agent, "Message 1", session=session) + + model.set_next_output([get_text_message("Reply 2")]) + await run_agent_async(runner_method, agent, "Message 2", session=session) + + # Test with limit=0 - should get NO history, just new message + model.set_next_output([get_text_message("Reply 3")]) + await run_agent_async( + runner_method, + agent, + "Message 3", + session=session, + run_config=RunConfig(session_settings=SessionSettings(limit=0)), + ) + + # Verify model received only the new message + last_input = model.last_turn_args["input"] + assert len(last_input) == 1 + assert last_input[0].get("content") == "Message 3" + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_limit_none_gets_all_history(runner_method): + """Test that session_limit=None retrieves entire history (default behavior).""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_limit_none.db" + session_id = "limit_none_test" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Build longer conversation + for i in range(1, 6): + model.set_next_output([get_text_message(f"Reply {i}")]) + await run_agent_async(runner_method, agent, f"Message {i}", session=session) + + # Verify 10 items in session (5 user + 5 assistant) + all_items = await session.get_items() + assert len(all_items) == 10 + + # Test with session_limit=None (default) - should get all history + model.set_next_output([get_text_message("Reply 6")]) + await run_agent_async( + runner_method, + agent, + "Message 6", + session=session, + run_config=RunConfig(session_settings=SessionSettings(limit=None)), + ) + + # Verify model received all history + new message + last_input = model.last_turn_args["input"] + assert len(last_input) == 11 # 10 history + 1 new + assert last_input[0].get("content") == "Message 1" + assert last_input[-1].get("content") == "Message 6" + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_limit_larger_than_history(runner_method): + """Test that session_limit larger than history size returns all items.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_limit_large.db" + session_id = "limit_large_test" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Build small conversation + model.set_next_output([get_text_message("Reply 1")]) + await run_agent_async(runner_method, agent, "Message 1", session=session) + + # Test with limit=100 (much larger than actual history) + model.set_next_output([get_text_message("Reply 2")]) + await run_agent_async( + runner_method, + agent, + "Message 2", + session=session, + run_config=RunConfig(session_settings=SessionSettings(limit=100)), + ) + + # Verify model received all available history + new message + last_input = model.last_turn_args["input"] + assert len(last_input) == 3 # 2 history + 1 new + assert last_input[0].get("content") == "Message 1" + # Assistant message has content as a list + assert last_input[1].get("content")[0]["text"] == "Reply 1" + assert last_input[2].get("content") == "Message 2" + + session.close() diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index b18d6fb92..e8a2cdece 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -2,6 +2,7 @@ from agents.items import TResponseInputItem from agents.memory.session import Session +from agents.memory.session_settings import SessionSettings class SimpleListSession(Session): @@ -9,6 +10,7 @@ class SimpleListSession(Session): def __init__(self, session_id: str = "test") -> None: self.session_id = session_id + self.session_settings = SessionSettings() self._items: list[TResponseInputItem] = [] async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: