Skip to content

Commit

Permalink
Tr/poetry fix (#215)
Browse files Browse the repository at this point in the history
* Add 'cover-agent-full-repo' script to pyproject.toml and update README usage instructions

* Add compatibility for Python versions before 3.11 by using typing_extensions

* Add '--diff-coverage' and '--branch' arguments to utils.py and update README test command

* Fix coverage calculation in failure message to use test_validator's current coverage

* Add retry logic to AICaller with tenacity for improved error handling

* Add error handling for CoverAgent execution in main_full_repo.py

* Add tenacity dependency to pyproject.toml for retry logic implementation

* Update max_allowed_runtime_seconds and max_tokens settings, adjust test file print format

* update lock

* tests
  • Loading branch information
mrT23 authored Nov 13, 2024
1 parent 91b81b1 commit 6aefef8
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 95 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ poetry run cover-agent-full-repo \
--project-language="python" \
--project-root="<path_to_your_repo>" \
--code-coverage-report-path="<path_to_your_repo>/coverage.xml" \
--test-command="coverage run -m pytest <relative_path_to_unittest_folder> --cov=<path_to_your_repo> --cov-report=xml --cov-report=term --log-cli-level=INFO --timeout=30" \
--test-command="coverage run -m pytest <relative_path_to_unittest_folder> --cov=<path_to_your_repo> --cov-report=xml --cov-report=term --log-cli-level=INFO" \
--model=bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0
```

Expand Down
37 changes: 33 additions & 4 deletions cover_agent/AICaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,31 @@
import time

import litellm
from functools import wraps
from wandb.sdk.data_types.trace_tree import Trace
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt, wait_fixed
MODEL_RETRIES = 3


def conditional_retry(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.enable_retry:
return func(self, *args, **kwargs)

@retry(
stop=stop_after_attempt(MODEL_RETRIES),
wait=wait_fixed(1)
)
def retry_wrapper():
return func(self, *args, **kwargs)

return retry_wrapper()

return wrapper

class AICaller:
def __init__(self, model: str, api_base: str = ""):
def __init__(self, model: str, api_base: str = "", enable_retry=True):
"""
Initializes an instance of the AICaller class.
Expand All @@ -17,7 +37,9 @@ def __init__(self, model: str, api_base: str = ""):
"""
self.model = model
self.api_base = api_base
self.enable_retry = enable_retry

@conditional_retry # You can access self.enable_retry here
def call_model(self, prompt: dict, max_tokens=4096, stream=True):
"""
Call the language model with the provided prompt and retrieve the response.
Expand Down Expand Up @@ -73,7 +95,11 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
):
completion_params["api_base"] = self.api_base

response = litellm.completion(**completion_params)
try:
response = litellm.completion(**completion_params)
except Exception as e:
print(f"Error calling LLM model: {e}")
raise e

if stream:
chunks = []
Expand All @@ -85,11 +111,14 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
time.sleep(
0.01
) # Optional: Delay to simulate more 'natural' response pacing

except Exception as e:
print(f"Error during streaming: {e}")
print(f"Error calling LLM model during streaming: {e}")
if self.enable_retry:
raise e
model_response = litellm.stream_chunk_builder(chunks, messages=messages)
print("\n")
# Build the final response from the streamed chunks
model_response = litellm.stream_chunk_builder(chunks, messages=messages)
content = model_response["choices"][0]["message"]["content"]
usage = model_response["usage"]
prompt_tokens = int(usage["prompt_tokens"])
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/CoverAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def run(self):
if self.args.diff_coverage:
failure_message = f"Reached maximum iteration limit without achieving desired diff coverage. Current Coverage: {round(self.test_validator.current_coverage * 100, 2)}%"
else:
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_validator.current_coverage * 100, 2)}%"
if self.args.strict_coverage:
# User requested strict coverage (similar to "--cov-fail-under in pytest-cov"). Fail with exist code 2.
self.logger.error(failure_message)
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/Runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run_command(command, cwd=None):
# Get the current time before running the test command, in milliseconds
command_start_time = int(round(time.time() * 1000))

max_allowed_runtime_seconds = get_settings().get("tests.max_allowed_runtime_seconds", 60)
max_allowed_runtime_seconds = get_settings().get("tests.max_allowed_runtime_seconds", 30)
# Ensure the command is executed with shell=True for string commands
try:
result = subprocess.run(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@

from enum import Enum, IntEnum, IntFlag
from typing import Dict, List, Literal, Union
from typing import NotRequired, TypedDict
try:
from typing import NotRequired, TypedDict
except ImportError: # before Python 3.11
from typing_extensions import NotRequired, TypedDict

URI = str
DocumentUri = str
Expand Down
5 changes: 4 additions & 1 deletion cover_agent/lsp_logic/multilspy/multilspy_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from __future__ import annotations

from enum import IntEnum, Enum
from typing import NotRequired, TypedDict, List, Dict, Union
try:
from typing import NotRequired, TypedDict, List, Dict, Union
except ImportError: # before Python 3.11
from typing_extensions import NotRequired, TypedDict, List, Dict, Union

URI = str
DocumentUri = str
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/lsp_logic/utils/utils_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def analyze_context(test_file, context_files, args, ai_caller):
context_files_include = [f for f in context_files if f != file]

if source_file:
print(f"Test file: `{test_file}` is a unit test file for source file: `{source_file}`")
print(f"Test file: `{test_file}`,\nis a unit test file for source file: `{source_file}`")
else:
print(f"Test file: `{test_file}` is not a unit test file")
except Exception as e:
Expand Down
20 changes: 12 additions & 8 deletions cover_agent/main_full_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ async def run():
source_file, context_files_include = await analyze_context(test_file, context_files, args, ai_caller)

if source_file:
# Run the CoverAgent for the test file
args_copy = copy.deepcopy(args)
args_copy.source_file_path = source_file
args_copy.test_command_dir = args.project_root
args_copy.test_file_path = test_file
args_copy.included_files = context_files_include
agent = CoverAgent(args_copy)
agent.run()
try:
# Run the CoverAgent for the test file
args_copy = copy.deepcopy(args)
args_copy.source_file_path = source_file
args_copy.test_command_dir = args.project_root
args_copy.test_file_path = test_file
args_copy.included_files = context_files_include
agent = CoverAgent(args_copy)
agent.run()
except Exception as e:
print(f"Error running CoverAgent for test file '{test_file}': {e}")
pass


def main():
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/settings/configuration.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[include_files]
limit_tokens=true
max_tokens=16000
max_tokens=20000

[tests]
max_allowed_runtime_seconds=30
10 changes: 10 additions & 0 deletions cover_agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,16 @@ def parse_args_full_repo():
default=100,
help="The desired coverage percentage. Default: %(default)s.",
)
parser.add_argument(
"--diff-coverage",
action="store_true",
default=False,
)
parser.add_argument(
"--branch",
type=str,
default="main",
)
return parser.parse_args()


Expand Down
Loading

0 comments on commit 6aefef8

Please sign in to comment.