mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Generic, TypeVar
|
|
|
|
# Numeric type of the tracked value: int for iterations, float for budget.
T = TypeVar('T', int, float)
|
|
|
|
|
|
@dataclass
class ControlFlag(Generic[T]):
    """Base class for control flags that manage limits and state transitions.

    This class only declares the shared fields and the method interface;
    every method here raises ``NotImplementedError`` and is expected to be
    overridden by a concrete subclass.
    """

    # Amount subclasses use when extending the limit.
    limit_increase_amount: T
    # Current reading of the tracked quantity.
    current_value: T
    # Threshold compared against ``current_value`` by subclasses.
    max_value: T
    # Stored on the instance; no method in this class reads it.
    headless_mode: bool = False
    # Outcome of the most recent limit check, as recorded by subclasses.
    _hit_limit: bool = False

    def reached_limit(self) -> bool:
        """Check if the limit has been reached.

        Returns:
            bool: True if the limit has been reached, False otherwise.

        Raises:
            NotImplementedError: Always, until overridden by a subclass.
        """
        raise NotImplementedError

    def increase_limit(self, headless_mode: bool) -> None:
        """Expand the limit when needed.

        Args:
            headless_mode: Mode flag supplied by the caller; how it is
                interpreted is up to the overriding subclass.

        Raises:
            NotImplementedError: Always, until overridden by a subclass.
        """
        raise NotImplementedError

    def step(self) -> None:
        """Advance the flag by one step.

        The overriding implementations in this file return nothing and
        raise ``RuntimeError`` once the limit has been reached.

        Raises:
            NotImplementedError: Always, until overridden by a subclass.
        """
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class IterationControlFlag(ControlFlag[int]):
    """Control flag that caps the number of iterations."""

    def reached_limit(self) -> bool:
        """Report whether the iteration count is at or past the maximum.

        The outcome is also remembered in ``_hit_limit`` so that a later
        ``increase_limit`` call knows the cap was actually hit.
        """
        limit_hit = self.current_value >= self.max_value
        self._hit_limit = limit_hit
        return limit_hit

    def increase_limit(self, headless_mode: bool) -> None:
        """Raise the iteration cap by ``limit_increase_amount``.

        Nothing happens in headless mode, or when the cap was not hit on
        the most recent check.
        """
        if headless_mode or not self._hit_limit:
            return
        self.max_value += self.limit_increase_amount
        self._hit_limit = False

    def step(self):
        """Count one iteration, raising ``RuntimeError`` if the cap is hit."""
        if self.reached_limit():
            message = (
                'Agent reached maximum iteration. '
                f'Current iteration: {self.current_value}, '
                f'max iteration: {self.max_value}'
            )
            raise RuntimeError(message)

        # The cap has not been reached yet, so account for this iteration.
        self.current_value += 1
|
|
|
|
|
|
@dataclass
class BudgetControlFlag(ControlFlag[float]):
    """Control flag for managing budget limits."""

    def reached_limit(self) -> bool:
        """Check if the budget limit has been reached.

        The result is also stored in ``_hit_limit`` so that
        ``increase_limit`` only acts after the limit was actually hit.

        Returns:
            bool: True if ``current_value`` is at or above ``max_value``.
        """
        self._hit_limit = self.current_value >= self.max_value
        return self._hit_limit

    def increase_limit(self, headless_mode: bool) -> None:
        """Expand the budget limit by adding the initial value to the current value.

        Args:
            headless_mode: Accepted to match the base-class signature; this
                implementation does not consult it.
        """
        if self._hit_limit:
            # The new ceiling is relative to what has already been spent,
            # not to the previous ceiling.
            self.max_value = self.current_value + self.limit_increase_amount
            self._hit_limit = False

    def step(self) -> None:
        """Check if we've reached the limit and update state accordingly.

        Note: Unlike IterationControlFlag, this doesn't increment the value
        as the budget is updated externally.

        Raises:
            RuntimeError: If the budget limit has been reached.
        """
        if self.reached_limit():
            current_str = f'{self.current_value:.2f}'
            max_str = f'{self.max_value:.2f}'
            # Adjacent string literals are joined with no separator, so the
            # space after the first sentence has to be written explicitly.
            raise RuntimeError(
                'Agent reached maximum budget for conversation. '
                f'Current budget: {current_str}, max budget: {max_str}'
            )
|