Robert Brennan f7e0c6cd06
Separate agent controller and server via EventStream (#1538)
* move towards event stream

* refactor agent state changes

* move agent state logic

* fix callbacks

* break on finish

* closer to working

* change frontend to accommodate new flow

* handle start action

* fix locked stream

* revert message

* logspam

* no async on close

* get rid of agent_task

* fix up closing

* better asyncio handling

* sleep to give back control

* fix key

* logspam

* update frontend agent state actions

* fix pause and cancel

* delint

* fix map

* delint

* wait for agent to finish

* fix unit test

* event stream enums

* fix merge issues

* fix lint

* fix test

* fix test

* add user message action

* add user message action

* fix up user messages

* fix main.py flow

* refactor message waiting

* lint

* fix test

* fix test
2024-05-05 19:20:01 +00:00

61 lines
2.0 KiB
Python

from .agent import AgentStateChangedObservation
from .browse import BrowserOutputObservation
from .commands import CmdOutputObservation, IPythonRunCellObservation
from .delegate import AgentDelegateObservation
from .empty import NullObservation
from .error import AgentErrorObservation
from .files import FileReadObservation, FileWriteObservation
from .message import AgentMessageObservation, UserMessageObservation
from .observation import Observation
from .recall import AgentRecallObservation
# Observation subclasses that observation_from_dict() can rebuild from a
# serialized dict. Each class must expose a class-level ``observation``
# attribute (its type tag), because OBSERVATION_TYPE_TO_CLASS is keyed on it.
# Any exported observation left out of this tuple cannot be deserialized.
observations = (
    NullObservation,
    CmdOutputObservation,
    IPythonRunCellObservation,
    BrowserOutputObservation,
    FileReadObservation,
    FileWriteObservation,
    UserMessageObservation,
    AgentMessageObservation,
    AgentRecallObservation,
    AgentDelegateObservation,
    AgentErrorObservation,
    AgentStateChangedObservation,
)
# Reverse index from each registered class's ``observation`` type tag to the
# class itself. observation_from_dict() uses it to pick which class to build.
OBSERVATION_TYPE_TO_CLASS = {
    registered.observation: registered  # type: ignore[attr-defined]
    for registered in observations
}
def observation_from_dict(observation: dict) -> Observation:
    """Rebuild an Observation instance from its serialized dict form.

    The ``'observation'`` key holds the type tag used to look up the concrete
    class in OBSERVATION_TYPE_TO_CLASS. ``'content'`` (default ``''``) is passed
    as the ``content`` argument. The entries of ``'extras'`` (default: none)
    are passed as keyword arguments. A ``'message'`` key and any other
    top-level keys are ignored. The input dict is not modified.

    Args:
        observation: Serialized observation dict.

    Returns:
        A new instance of the Observation subclass registered for the tag.

    Raises:
        KeyError: If the ``'observation'`` key is missing, or its value is not a
            registered observation type.
    """
    if 'observation' not in observation:
        raise KeyError(f"'observation' key is not found in {observation=}")
    observation_type = observation['observation']
    observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation_type)
    if observation_class is None:
        raise KeyError(
            f"'{observation_type}' is not defined. "
            f'Available observations: {list(OBSERVATION_TYPE_TO_CLASS)}'
        )
    content = observation.get('content', '')
    # Treat a serialized ``extras: None`` like a missing key; ``**None`` would
    # otherwise raise TypeError.
    extras = observation.get('extras') or {}
    return observation_class(content=content, **extras)
# Public names re-exported by this package. Every observation class that is
# registered for deserialization is listed here as well.
__all__ = [
    'Observation',
    'NullObservation',
    'CmdOutputObservation',
    'IPythonRunCellObservation',
    'BrowserOutputObservation',
    'FileReadObservation',
    'FileWriteObservation',
    'UserMessageObservation',
    'AgentMessageObservation',
    'AgentRecallObservation',
    'AgentDelegateObservation',
    'AgentErrorObservation',
    'AgentStateChangedObservation',
]