diff --git a/src/git_bob/_ai_github_utilities.py b/src/git_bob/_ai_github_utilities.py index 0d5296e..7710c5c 100644 --- a/src/git_bob/_ai_github_utilities.py +++ b/src/git_bob/_ai_github_utilities.py @@ -403,7 +403,7 @@ def solve_github_issue(repository, issue, llm_model, prompt_function, base_branc discussion = modify_discussion(Config.git_utilities.get_conversation_on_issue(repository, issue)) print("Discussion:", discussion) - all_files = "* " + "\n* ".join(Config.git_utilities.list_repository_files(repository)) + all_files = "* " + "\n* ".join(Config.git_utilities.list_repository_files(repository, branch_name=base_branch)) modifications = prompt_function(f""" Given a list of files in the repository {repository} and a github issues description (# {issue}), determine which files need to be modified, renamed or deleted to solve the issue. diff --git a/src/git_bob/_github_utilities.py b/src/git_bob/_github_utilities.py index c396a40..4dbd16f 100644 --- a/src/git_bob/_github_utilities.py +++ b/src/git_bob/_github_utilities.py @@ -246,7 +246,7 @@ def get_issue_details(repository: str, issue: int) -> str: return content -def list_repository_files(repository: str) -> list: +def list_repository_files(repository: str, branch_name: str = "main") -> list: """ List all files in a given GitHub repository. @@ -257,6 +257,8 @@ def list_repository_files(repository: str) -> list: ---------- repository : str The full name of the GitHub repository (e.g., "username/repo-name"). + branch_name : str, optional + The name of the branch or tag (default is 'main'). Returns ------- @@ -269,7 +271,7 @@ def list_repository_files(repository: str) -> list: repo = get_repository_handle(repository) # Get all contents of the repository - contents = repo.get_contents("") + contents = repo.get_contents("", ref=branch_name) # List to store all file paths all_files = [] @@ -296,7 +298,7 @@ def get_repository_file_contents(repository: str, branch_name, file_paths: list) repository : str The full name of the GitHub repository (e.g., "username/repo-name"). branch_name : str, optional - The name of the branch or tag (default is 'main'). + The name of the branch or tag. file_paths : list A list of file paths within the repository to retrieve the contents of. @@ -305,7 +307,7 @@ def get_repository_file_contents(repository: str, branch_name, file_paths: list) dict A dictionary where keys are file paths and values are the contents of the files. """ - Log().log(f"-> get_repository_file_contents({repository}, {file_paths})") + Log().log(f"-> get_repository_file_contents({repository}, {branch_name}, {file_paths})") # Dictionary to store file contents file_contents = {} diff --git a/src/git_bob/_gitlab_utilities.py b/src/git_bob/_gitlab_utilities.py index f59e914..b139567 100644 --- a/src/git_bob/_gitlab_utilities.py +++ b/src/git_bob/_gitlab_utilities.py @@ -186,7 +186,7 @@ def get_issue_details(repository: str, issue: int) -> str: return content -def list_repository_files(repository: str): +def list_repository_files(repository: str, branch_name: str = "main") -> list: """ List all files in the specified GitLab repository branch. @@ -194,20 +194,22 @@ def list_repository_files(repository: str): ---------- repository : str The full name of the GitLab project (e.g., "username/repo-name"). + branch_name : str, optional + The name of the branch or tag (default is 'main'). Returns ------- list A list of file paths in the repository. """ - Log().log(f"-> list_repository_files({repository})") + Log().log(f"-> list_repository_files({repository}, {branch_name})") repo = get_repository_handle(repository) files = [] path_stack = [''] while path_stack: path = path_stack.pop() - tree = repo.repository_tree(path=path) + tree = repo.repository_tree(path=path, ref=branch_name) for item in tree: if item['type'] == 'blob': files.append(item['path']) @@ -234,7 +236,7 @@ def get_repository_file_contents(repository:str, branch_name, file_paths: list): str The content of the file as a string. """ - Log().log(f"-> get_repository_file_contents({repository}, {file_paths}, {branch_name})") + Log().log(f"-> get_repository_file_contents({repository}, {branch_name}, {file_paths})") project = get_repository_handle(repository) file_contents = {} diff --git a/tests/test_gitlab_utilities.py b/tests/test_gitlab_utilities.py index 872cc16..a39f28a 100644 --- a/tests/test_gitlab_utilities.py +++ b/tests/test_gitlab_utilities.py @@ -59,7 +59,7 @@ def test_list_repository_files(): Config.git_server_url = "https://gitlab.com" from git_bob._gitlab_utilities import list_repository_files - files = list(list_repository_files("haesleinhuepf/git-bob")) + files = list(list_repository_files("haesleinhuepf/git-bob", branch_name="main")) assert "README.md" in files assert "LICENSE" in files