Skip to content

Commit

Permalink
Update task result references
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Nov 14, 2024
1 parent ea11324 commit fa18e8a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,27 +485,27 @@ class TestSuccessTool:
def test_success_tool(self):
task = Task("choose 5", result_type=int)
tool = task.get_success_tool()
tool.run(input=dict(task_result=5))
tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

def test_success_tool_with_list_of_options(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.is_successful()
assert task.result == "good"

def test_success_tool_with_list_of_options_requires_int(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
with pytest.raises(ValueError):
tool.run(input=dict(task_result="good"))
tool.run(input=dict(result="good"))

def test_tuple_of_ints_result(self):
task = Task("choose 5", result_type=(4, 5, 6))
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.result == 5

def test_tuple_of_pydantic_models_result(self):
Expand All @@ -518,7 +518,7 @@ class Person(BaseModel):
result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
)
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.result == Person(name="Bob", age=35)
assert isinstance(task.result, Person)

Expand Down Expand Up @@ -604,7 +604,7 @@ def test_invalid_completion_tool(self):
def test_manual_success_tool(self):
task = Task(objective="Test task", completion_tools=[], result_type=int)
success_tool = task.get_success_tool()
success_tool.run(input=dict(task_result=5))
success_tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def task(self, default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down
4 changes: 2 additions & 2 deletions tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand All @@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm):
assert events[3].event == "tool-result"
assert events[3].tool_result.tool_call == {
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down

0 comments on commit fa18e8a

Please sign in to comment.