Skip to content

Commit

Permalink
fix: improve state handling and JSON parsing
Browse files Browse the repository at this point in the history
- Add better JSON parsing with error handling and fallback
- Improve state handling using dataclasses with proper default values
- Add better extraction of clickable elements
- Add more error logging for diagnostics

Fixes:
- TypeError "'int' object is not iterable" in state handling
- JSON parsing errors in model responses
  • Loading branch information
ManojINaik committed Jan 8, 2025
1 parent dcb3914 commit 7cb5ec7
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 37 deletions.
107 changes: 92 additions & 15 deletions src/agent/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ def update_step_info(
if step_info is None:
return

step_info.step_number += 1
step_info.step_number = getattr(step_info, 'step_number', 0) + 1
important_contents = model_output.current_state.important_contents
if (
important_contents
and "None" not in important_contents
and important_contents not in step_info.memory
and important_contents not in getattr(step_info, 'memory', '')
):
step_info.memory += important_contents + "\n"
step_info.memory = getattr(step_info, 'memory', '') + important_contents + "\n"

completed_contents = model_output.current_state.completed_contents
if completed_contents and "None" not in completed_contents:
Expand All @@ -148,14 +148,49 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu
"""Get next action from LLM based on current state"""

ret = self.llm.invoke(input_messages)
parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
parsed: AgentOutput = self.AgentOutput(**parsed_json)
# cut the number of actions to max_actions_per_step
parsed.action = parsed.action[: self.max_actions_per_step]
self._log_response(parsed)
self.n_steps += 1

return parsed
content = ret.content

# Clean up the content to ensure valid JSON
if "```json" in content:
# Extract JSON from code block
start = content.find("```json") + 7
end = content.find("```", start)
if end == -1:
end = len(content)
content = content[start:end]
else:
# Try to find JSON object
start = content.find("{")
end = content.rfind("}") + 1
if start >= 0 and end > start:
content = content[start:end]

# Clean up any remaining whitespace or newlines
content = content.strip()

try:
parsed_json = json.loads(content)
parsed: AgentOutput = self.AgentOutput(**parsed_json)
# cut the number of actions to max_actions_per_step
parsed.action = parsed.action[: self.max_actions_per_step]
self._log_response(parsed)
self.n_steps += 1
return parsed
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON: {str(e)}")
logger.error(f"Content was: {content}")
# Create a default response
from .custom_views import CustomAgentBrain
return CustomAgentOutput(
current_state=CustomAgentBrain(
prev_action_evaluation="Failed - Error parsing response",
important_contents="None",
completed_contents="",
thought="Failed to parse the response. Will retry with a simpler action.",
summary="Retry with simpler action"
),
action=[{"go_to_url": {"url": "https://www.google.com"}}]
)

@time_execution_async("--step")
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
Expand All @@ -167,18 +202,59 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:

try:
state = await self.browser_context.get_state(use_vision=self.use_vision)
self.message_manager.add_state_message(state, self._last_result, step_info)
if state is None:
logger.error("Failed to get browser state")
return

# Create a new state object with default values
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ElementTree:
clickable_elements: List[str] = field(default_factory=list)

def clickable_elements_to_string(self, include_attributes=None):
return "\n".join(self.clickable_elements) if self.clickable_elements else ""

@dataclass
class BrowserState:
url: str = ""
tabs: List[str] = field(default_factory=list)
element_tree: ElementTree = field(default_factory=ElementTree)
screenshot: Optional[str] = None

browser_state = BrowserState()
browser_state.url = getattr(state, 'url', '')
browser_state.tabs = getattr(state, 'tabs', [])
browser_state.screenshot = getattr(state, 'screenshot', None)

# Extract clickable elements if available
if hasattr(state, 'element_tree') and hasattr(state.element_tree, 'clickable_elements'):
browser_state.element_tree.clickable_elements = state.element_tree.clickable_elements

self.message_manager.add_state_message(browser_state, self._last_result, step_info)
input_messages = self.message_manager.get_messages()
if not input_messages:
logger.error("Failed to get input messages")
return

model_output = await self.get_next_action(input_messages)
if model_output is None:
logger.error("Failed to get next action")
return

self.update_step_info(model_output, step_info)
logger.info(f"🧠 All Memory: {step_info.memory}")
logger.info(f"🧠 All Memory: {getattr(step_info, 'memory', '')}")
self._save_conversation(input_messages, model_output)
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
self.message_manager.add_model_output(model_output)

result: list[ActionResult] = await self.controller.multi_act(
model_output.action, self.browser_context
)
if result is None:
result = []
self._last_result = result

if len(result) > 0 and result[-1].is_done:
Expand All @@ -187,14 +263,15 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
self.consecutive_failures = 0

except Exception as e:
logger.error(f"Error in step: {str(e)}")
result = self._handle_step_error(e)
self._last_result = result

finally:
if not result:
return
for r in result:
if r.error:
if r and r.error:
self.telemetry.capture(
AgentStepErrorTelemetryEvent(
agent_id=self.agent_id,
Expand All @@ -219,7 +296,7 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
step_info = CustomAgentStepInfo(
task=self.task,
add_infos=self.add_infos,
step_number=1,
step_number=0, # Start at 0 since update_step_info will increment
max_steps=max_steps,
memory="",
task_progress="",
Expand Down
65 changes: 43 additions & 22 deletions src/agent/custom_massage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,47 @@ def add_state_message(
) -> None:
"""Add browser state as human message"""

# if a result should be kept in memory, add it directly to history and add the state without the result
if result:
for r in result:
if r.include_in_memory:
if r.extracted_content:
msg = HumanMessage(content=str(r.extracted_content))
self._add_message_with_tokens(msg)
if r.error:
msg = HumanMessage(
content=str(r.error)[-self.max_error_length :]
)
self._add_message_with_tokens(msg)
result = None # if result in history, we dont want to add it again
try:
# if a result should be kept in memory, add it directly to history and add the state without the result
if result:
for r in result:
if r and r.include_in_memory:
if r.extracted_content:
msg = HumanMessage(content=str(r.extracted_content))
self._add_message_with_tokens(msg)
if r.error:
msg = HumanMessage(
content=str(r.error)[-self.max_error_length :]
)
self._add_message_with_tokens(msg)
result = None # if result in history, we dont want to add it again

# otherwise add state message and result to next message (which will not stay in memory)
state_message = CustomAgentMessagePrompt(
state,
result,
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
step_info=step_info,
).get_user_message()
self._add_message_with_tokens(state_message)
# Create state message with safe attribute access
state_message = CustomAgentMessagePrompt(
state,
result,
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
step_info=step_info,
).get_user_message()

if state_message and hasattr(state_message, 'content'):
if isinstance(state_message.content, str):
self._add_message_with_tokens(state_message)
elif isinstance(state_message.content, list):
# Handle multi-modal messages (text + image)
has_valid_content = False
for item in state_message.content:
if isinstance(item, dict):
if item.get('type') == 'text' and item.get('text'):
has_valid_content = True
elif item.get('type') == 'image_url' and item.get('image_url', {}).get('url'):
has_valid_content = True
if has_valid_content:
self._add_message_with_tokens(state_message)

except Exception as e:
logger.error(f"Error in add_state_message: {str(e)}")
# Create a basic message if state processing fails
msg = HumanMessage(content="Error processing browser state")
self._add_message_with_tokens(msg)

0 comments on commit 7cb5ec7

Please sign in to comment.