Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify result instructions #387

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from controlflow.utilities.general import unwrap

CONTROLFLOW_ENV_FILE = os.getenv("CONTROLFLOW_ENV_FILE", "~/.controlflow/.env")
CONTROLFLOW_ENV_FILE = os.path.expanduser(
os.path.expandvars(os.getenv("CONTROLFLOW_ENV_FILE", "~/.controlflow/.env"))
)


class ControlFlowSettings(BaseSettings):
Expand Down
16 changes: 9 additions & 7 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,8 @@ def get_success_tool(self) -> Tool:
instructions.append(
unwrap(
f"""
Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
Use this tool to mark the task as successful and provide a
result. The result schema is: {result_schema}
"""
)
)
Expand All @@ -696,8 +697,9 @@ def succeed(**kwargs) -> str:
instructions.append(
unwrap(
f"""
Use this tool to mark the task as successful and provide a result with the `task_result` kwarg.
The `task_result` schema is: {{"task_result": {result_schema}}}
Use this tool to mark the task as successful and provide a
`result` value. The `result` value has the following schema:
{result_schema}.
"""
)
)
Expand All @@ -709,18 +711,18 @@ def succeed(**kwargs) -> str:
include_return_description=False,
metadata=metadata,
)
def succeed(task_result: result_schema) -> str: # type: ignore
def succeed(result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if task_result not in options:
if result not in options:
raise ValueError(
f"Invalid option. Please choose one of {options}"
)
task_result = options[task_result]
self.mark_successful(result=task_result)
result = options[result]
self.mark_successful(result=result)
return f"{self.friendly_name()} marked successful."

return succeed
Expand Down
12 changes: 6 additions & 6 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,27 +485,27 @@ class TestSuccessTool:
def test_success_tool(self):
task = Task("choose 5", result_type=int)
tool = task.get_success_tool()
tool.run(input=dict(task_result=5))
tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

def test_success_tool_with_list_of_options(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.is_successful()
assert task.result == "good"

def test_success_tool_with_list_of_options_requires_int(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
with pytest.raises(ValueError):
tool.run(input=dict(task_result="good"))
tool.run(input=dict(result="good"))

def test_tuple_of_ints_result(self):
task = Task("choose 5", result_type=(4, 5, 6))
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.result == 5

def test_tuple_of_pydantic_models_result(self):
Expand All @@ -518,7 +518,7 @@ class Person(BaseModel):
result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
)
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.result == Person(name="Bob", age=35)
assert isinstance(task.result, Person)

Expand Down Expand Up @@ -604,7 +604,7 @@ def test_invalid_completion_tool(self):
def test_manual_success_tool(self):
task = Task(objective="Test task", completion_tools=[], result_type=int)
success_tool = task.get_success_tool()
success_tool.run(input=dict(task_result=5))
success_tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def task(self, default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down
4 changes: 2 additions & 2 deletions tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand All @@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm):
assert events[3].event == "tool-result"
assert events[3].tool_result.tool_call == {
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down