Skip to content

Commit

Permalink
feat: This commit adds pull request support to SCM generator so the g…
Browse files Browse the repository at this point in the history
…enerator

can create ArgoCD apps for PRs as well.

Fixes argoproj#466

Signed-off-by: Fardin Khanjani <[email protected]>
  • Loading branch information
fardin01 committed Feb 24, 2022
1 parent e900eab commit 03b148b
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 31 deletions.
8 changes: 8 additions & 0 deletions api/v1alpha1/applicationset_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ type SCMProviderGeneratorGithub struct {
TokenRef *SecretRef `json:"tokenRef,omitempty"`
// Scan all branches instead of just the default branch.
AllBranches bool `json:"allBranches,omitempty"`
// Scan all pull requests
AllPullRequests bool `json:"allPullRequests,omitempty"`
}

// SCMProviderGeneratorGitlab defines a connection info specific to Gitlab.
Expand All @@ -328,6 +330,8 @@ type SCMProviderGeneratorGitlab struct {
TokenRef *SecretRef `json:"tokenRef,omitempty"`
// Scan all branches instead of just the default branch.
AllBranches bool `json:"allBranches,omitempty"`
// Scan all pull requests
AllPullRequests bool `json:"allPullRequests,omitempty"`
}

// SCMProviderGeneratorFilter is a single repository filter.
Expand All @@ -342,6 +346,10 @@ type SCMProviderGeneratorFilter struct {
LabelMatch *string `json:"labelMatch,omitempty"`
// A regex which must match the branch name.
BranchMatch *string `json:"branchMatch,omitempty"`
// A regex which must match the pull request tile.
PullRequestTitleMatch *string `json:"pullRequestTitleMatch,omitempty"`
// A regex which must match at least one pull request label.
PullRequestLabelMatch *string `json:"pullRequestLabelMatch,omitempty"`
}

// PullRequestGenerator defines a generator that scrapes a PullRequest API to find candidate pull requests.
Expand Down
4 changes: 2 additions & 2 deletions pkg/generators/scm_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
if err != nil {
return nil, fmt.Errorf("error fetching Github token: %v", err)
}
provider, err = scm_provider.NewGithubProvider(ctx, providerConfig.Github.Organization, token, providerConfig.Github.API, providerConfig.Github.AllBranches)
provider, err = scm_provider.NewGithubProvider(ctx, providerConfig.Github.Organization, token, providerConfig.Github.API, providerConfig.Github.AllBranches, providerConfig.Github.AllPullRequests)
if err != nil {
return nil, fmt.Errorf("error initializing Github service: %v", err)
}
Expand All @@ -73,7 +73,7 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
if err != nil {
return nil, fmt.Errorf("error fetching Gitlab token: %v", err)
}
provider, err = scm_provider.NewGitlabProvider(ctx, providerConfig.Gitlab.Group, token, providerConfig.Gitlab.API, providerConfig.Gitlab.AllBranches, providerConfig.Gitlab.IncludeSubgroups)
provider, err = scm_provider.NewGitlabProvider(ctx, providerConfig.Gitlab.Group, token, providerConfig.Gitlab.API, providerConfig.Gitlab.AllBranches, providerConfig.Gitlab.IncludeSubgroups, providerConfig.Gitlab.AllPullRequests)
if err != nil {
return nil, fmt.Errorf("error initializing Gitlab service: %v", err)
}
Expand Down
69 changes: 63 additions & 6 deletions pkg/services/scm_provider/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ import (
)

type GithubProvider struct {
client *github.Client
organization string
allBranches bool
client *github.Client
organization string
allBranches bool
allPullRequests bool
}

var _ SCMProviderService = &GithubProvider{}

func NewGithubProvider(ctx context.Context, organization string, token string, url string, allBranches bool) (*GithubProvider, error) {
func NewGithubProvider(ctx context.Context, organization string, token string, url string, allBranches bool, allPullRequests bool) (*GithubProvider, error) {
var ts oauth2.TokenSource
// Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits.
if token == "" {
Expand All @@ -40,7 +41,7 @@ func NewGithubProvider(ctx context.Context, organization string, token string, u
return nil, err
}
}
return &GithubProvider{client: client, organization: organization, allBranches: allBranches}, nil
return &GithubProvider{client: client, organization: organization, allBranches: allBranches, allPullRequests: allPullRequests}, nil
}

func (g *GithubProvider) GetBranches(ctx context.Context, repo *Repository) ([]*Repository, error) {
Expand All @@ -64,6 +65,32 @@ func (g *GithubProvider) GetBranches(ctx context.Context, repo *Repository) ([]*
return repos, nil
}

func (g *GithubProvider) GetPullRequests(ctx context.Context, repo *Repository) ([]*Repository, error) {
repos := []*Repository{}
pullRequests, err := g.listPullRequests(ctx, repo)
if err != nil {
return nil, fmt.Errorf("error listing pull requests for %s/%s: %v", repo.Organization, repo.Repository, err)
}

// go-github's PullRequest type does not have a GetLabel() function.
var labels []string
for _, pullRequest := range pullRequests {
for _, label := range pullRequest.Labels {
labels = append(labels, label.GetName())
}
repos = append(repos, &Repository{
Organization: repo.Organization,
Repository: repo.Repository,
URL: repo.URL,
Branch: pullRequest.GetTitle(),
SHA: pullRequest.GetHead().GetSHA(),
Labels: labels,
RepositoryId: repo.RepositoryId,
})
}
return repos, nil
}

func (g *GithubProvider) ListRepos(ctx context.Context, cloneProtocol string) ([]*Repository, error) {
opt := &github.RepositoryListByOrgOptions{
ListOptions: github.ListOptions{PerPage: 100},
Expand Down Expand Up @@ -104,7 +131,7 @@ func (g *GithubProvider) ListRepos(ctx context.Context, cloneProtocol string) ([

func (g *GithubProvider) RepoHasPath(ctx context.Context, repo *Repository, path string) (bool, error) {
_, _, resp, err := g.client.Repositories.GetContents(ctx, repo.Organization, repo.Repository, path, &github.RepositoryContentGetOptions{
Ref: repo.Branch,
Ref: repo.SHA,
})
// 404s are not an error here, just a normal false.
if resp != nil && resp.StatusCode == 404 {
Expand Down Expand Up @@ -153,3 +180,33 @@ func (g *GithubProvider) listBranches(ctx context.Context, repo *Repository) ([]
}
return branches, nil
}

func (g *GithubProvider) listPullRequests(ctx context.Context, repo *Repository) ([]github.PullRequest, error) {

if !g.allPullRequests {
return nil, nil
}

opt := &github.PullRequestListOptions{
ListOptions: github.ListOptions{PerPage: 100},
}

githubPullRequests := []github.PullRequest{}

for {
allPullRequests, resp, err := g.client.PullRequests.List(ctx, repo.Organization, repo.Repository, opt)
if err != nil {
return nil, err
}

for _, pr := range allPullRequests {
githubPullRequests = append(githubPullRequests, *pr)
}

if resp.NextPage == 0 {
break
}
opt.Page = resp.NextPage
}
return githubPullRequests, nil
}
18 changes: 12 additions & 6 deletions pkg/services/scm_provider/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func checkRateLimit(t *testing.T, err error) {

func TestGithubListRepos(t *testing.T) {
cases := []struct {
name, proto, url string
hasError, allBranches bool
branches []string
filters []v1alpha1.SCMProviderGeneratorFilter
name, proto, url string
hasError, allBranches, allPullRequests bool
branches []string
filters []v1alpha1.SCMProviderGeneratorFilter
}{
{
name: "blank protocol",
Expand Down Expand Up @@ -67,11 +67,17 @@ func TestGithubListRepos(t *testing.T) {
url: "[email protected]:argoproj/applicationset.git",
branches: []string{"master", "release-0.1.0"},
},
{
name: "all pull requests",
allPullRequests: true,
url: "[email protected]:argoproj/applicationset.git",
branches: []string{"pr-1", "pr-2"},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
provider, _ := NewGithubProvider(context.Background(), "argoproj", "", "", c.allBranches)
provider, _ := NewGithubProvider(context.Background(), "argoproj", "", "", c.allBranches, c.allPullRequests)
rawRepos, err := ListRepos(context.Background(), provider, c.filters, c.proto)
if c.hasError {
assert.Error(t, err)
Expand All @@ -98,7 +104,7 @@ func TestGithubListRepos(t *testing.T) {
}

func TestGithubHasPath(t *testing.T) {
host, _ := NewGithubProvider(context.Background(), "argoproj", "", "", false)
host, _ := NewGithubProvider(context.Background(), "argoproj", "", "", false, false)
repo := &Repository{
Organization: "argoproj",
Repository: "applicationset",
Expand Down
54 changes: 52 additions & 2 deletions pkg/services/scm_provider/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ type GitlabProvider struct {
organization string
allBranches bool
includeSubgroups bool
allPullRequests bool
}

var _ SCMProviderService = &GitlabProvider{}

func NewGitlabProvider(ctx context.Context, organization string, token string, url string, allBranches, includeSubgroups bool) (*GitlabProvider, error) {
func NewGitlabProvider(ctx context.Context, organization string, token string, url string, allBranches, includeSubgroups, allPullRequests bool) (*GitlabProvider, error) {
// Undocumented environment variable to set a default token, to be used in testing to dodge anonymous rate limits.
if token == "" {
token = os.Getenv("GITLAB_TOKEN")
Expand All @@ -36,7 +37,7 @@ func NewGitlabProvider(ctx context.Context, organization string, token string, u
return nil, err
}
}
return &GitlabProvider{client: client, organization: organization, allBranches: allBranches, includeSubgroups: includeSubgroups}, nil
return &GitlabProvider{client: client, organization: organization, allBranches: allBranches, includeSubgroups: includeSubgroups, allPullRequests: allPullRequests}, nil
}

func (g *GitlabProvider) GetBranches(ctx context.Context, repo *Repository) ([]*Repository, error) {
Expand All @@ -60,6 +61,28 @@ func (g *GitlabProvider) GetBranches(ctx context.Context, repo *Repository) ([]*
return repos, nil
}

func (g *GitlabProvider) GetPullRequests(ctx context.Context, repo *Repository) ([]*Repository, error) {
repos := []*Repository{}

pullRequests, err := g.listPullRequests(ctx, repo)
if err != nil {
return nil, err
}

for _, pullRequest := range pullRequests {
repos = append(repos, &Repository{
Organization: repo.Organization,
Repository: repo.Repository,
URL: repo.URL,
Branch: pullRequest.Title,
SHA: pullRequest.SHA,
Labels: pullRequest.Labels,
RepositoryId: repo.RepositoryId,
})
}
return repos, nil
}

func (g *GitlabProvider) ListRepos(ctx context.Context, cloneProtocol string) ([]*Repository, error) {
opt := &gitlab.ListGroupProjectsOptions{
ListOptions: gitlab.ListOptions{PerPage: 100},
Expand Down Expand Up @@ -149,3 +172,30 @@ func (g *GitlabProvider) listBranches(_ context.Context, repo *Repository) ([]gi
}
return branches, nil
}

func (g *GitlabProvider) listPullRequests(_ context.Context, repo *Repository) ([]gitlab.MergeRequest, error) {
if !g.allPullRequests {
return nil, nil
}

opt := &gitlab.ListProjectMergeRequestsOptions{
ListOptions: gitlab.ListOptions{PerPage: 100},
}

pullRequests := []gitlab.MergeRequest{}
for {
gitlabPullRequests, resp, err := g.client.MergeRequests.ListProjectMergeRequests(repo.RepositoryId, opt)
if err != nil {
return nil, err
}
for _, gitlabPullRequest := range gitlabPullRequests {
pullRequests = append(pullRequests, *gitlabPullRequest)
}

if resp.NextPage == 0 {
break
}
opt.Page = resp.NextPage
}
return pullRequests, nil
}
12 changes: 6 additions & 6 deletions pkg/services/scm_provider/gitlab_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (

func TestGitlabListRepos(t *testing.T) {
cases := []struct {
name, proto, url string
hasError, allBranches, includeSubgroups bool
branches []string
filters []v1alpha1.SCMProviderGeneratorFilter
name, proto, url string
hasError, allBranches, includeSubgroups, allPullRequests bool
branches []string
filters []v1alpha1.SCMProviderGeneratorFilter
}{
{
name: "blank protocol",
Expand Down Expand Up @@ -45,7 +45,7 @@ func TestGitlabListRepos(t *testing.T) {

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
provider, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", c.allBranches, c.includeSubgroups)
provider, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", c.allBranches, c.includeSubgroups, c.allPullRequests)
rawRepos, err := ListRepos(context.Background(), provider, c.filters, c.proto)
if c.hasError {
assert.NotNil(t, err)
Expand All @@ -72,7 +72,7 @@ func TestGitlabListRepos(t *testing.T) {
}

func TestGitlabHasPath(t *testing.T) {
host, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", false, true)
host, _ := NewGitlabProvider(context.Background(), "test-argocd-proton", "", "", false, true, false)
repo := &Repository{
Organization: "test-argocd-proton",
Repository: "argocd",
Expand Down
20 changes: 19 additions & 1 deletion pkg/services/scm_provider/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,25 @@ func (m *MockProvider) GetBranches(_ context.Context, repo *Repository) ([]*Repo
branchRepos = append(branchRepos, candidateRepo)
}
}

}
return branchRepos, nil
}

func (m *MockProvider) GetPullRequests(_ context.Context, repo *Repository) ([]*Repository, error) {
pullRequestRepos := []*Repository{}
for _, candidateRepo := range m.Repos {
if candidateRepo.Repository == repo.Repository {
found := false
for _, alreadySetRepo := range pullRequestRepos {
if alreadySetRepo.Branch == candidateRepo.Branch {
found = true
break
}
}
if !found {
pullRequestRepos = append(pullRequestRepos, candidateRepo)
}
}
}
return pullRequestRepos, nil
}
14 changes: 9 additions & 5 deletions pkg/services/scm_provider/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ type SCMProviderService interface {
ListRepos(context.Context, string) ([]*Repository, error)
RepoHasPath(context.Context, *Repository, string) (bool, error)
GetBranches(context.Context, *Repository) ([]*Repository, error)
GetPullRequests(context.Context, *Repository) ([]*Repository, error)
}

// A compiled version of SCMProviderGeneratorFilter for performance.
type Filter struct {
RepositoryMatch *regexp.Regexp
PathsExist []string
LabelMatch *regexp.Regexp
BranchMatch *regexp.Regexp
FilterType FilterType
RepositoryMatch *regexp.Regexp
PathsExist []string
LabelMatch *regexp.Regexp
BranchMatch *regexp.Regexp
PullRequestTitleMatch *regexp.Regexp
PullRequestLabelMatch *regexp.Regexp
FilterType FilterType
}

// A convenience type for indicating where to apply a filter
Expand All @@ -39,4 +42,5 @@ const (
FilterTypeUndefined FilterType = iota
FilterTypeBranch
FilterTypeRepo
FilterTypePullRequest
)
Loading

0 comments on commit 03b148b

Please sign in to comment.