Skip to content

Commit

Permalink
Always authenticate app installation
Browse files Browse the repository at this point in the history
And make sure the type system checks correct
credential usage.
  • Loading branch information
elegaanz committed Sep 25, 2024
1 parent f43af1a commit 3bc8b56
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 49 deletions.
25 changes: 17 additions & 8 deletions src/github.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub async fn hook_server() {

let app = Router::new()
.route("/", get(index))
.route("/github-hook", post(github_hook))
.route("/github-hook", post(github_hook::<GitHub<AuthJwt>>))
.route("/force-review/:install/:sha", get(force))
.layer(tower_http::trace::TraceLayer::new_for_http())
.with_state(state);
Expand Down Expand Up @@ -95,6 +95,17 @@ async fn force(
"Invalid repository path"
})?;

let installation = Installation {
id: str::parse(&install).map_err(|_| "Invalid installation ID")?,
};
let api_client = api_client
.auth_installation(&installation)
.await
.map_err(|e| {
debug!("Failed to authenticate installation: {}", e);
"Failed to authenticate installation"
})?;

let pr = MinimalPullRequest { number: pr };
let full_pr = pr
.get_full(&api_client, repository.owner(), repository.name())
Expand All @@ -110,9 +121,7 @@ async fn force(
api_client,
HookPayload::CheckSuite(CheckSuitePayload {
action: CheckSuiteAction::Requested,
installation: Installation {
id: str::parse(&install).map_err(|_| "Invalid installation ID")?,
},
installation,
repository,
check_suite: CheckSuite {
head_sha: sha,
Expand All @@ -130,13 +139,13 @@ async fn force(
}

/// The route to handle GitHub hooks. Mounted on `/github-hook`.
async fn github_hook(
async fn github_hook<G: GitHubAuth>(
State(state): State<AppState>,
mut api_client: GitHub,
api_client: G,
payload: HookPayload,
) -> Result<(), WebError> {
debug!("GitHub hook was triggered");
api_client.auth_installation(&payload).await?;
let api_client = api_client.auth_installation(&payload).await?;
debug!("Successfully authenticated application");

let (head_sha, repository, pr, previous_check_run) = match payload {
Expand Down Expand Up @@ -190,7 +199,7 @@ async fn github_hook(
async fn inner(
state: AppState,
head_sha: String,
api_client: GitHub,
api_client: GitHub<AuthInstallation>,
repository: Repository,
previous_check_run: Option<CheckRun>,
pr: Option<PullRequest>,
Expand Down
124 changes: 88 additions & 36 deletions src/github/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ use reqwest::{RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use tracing::{debug, warn};

use self::{
check::{CheckRun, CheckRunId, CheckRunOutput},
hook::HookPayload,
};
use self::check::{CheckRun, CheckRunId, CheckRunOutput};

use super::AppState;

Expand Down Expand Up @@ -67,27 +64,95 @@ impl From<serde_json::Error> for ApiError {

type ApiResult<T> = Result<T, ApiError>;

/// Authentication for the GitHub API using a JWT token.
pub struct AuthJwt(String);

impl Display for AuthJwt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

/// Authentication for the GitHub API using an installation token, that
/// is scoped to a specific organization or set of repositories, but that
/// can do more than a [`AuthJwt`] token.
pub struct AuthInstallation(String);

impl Display for AuthInstallation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

/// A GitHub API client
pub struct GitHub {
jwt: String,
pub struct GitHub<A = AuthJwt> {
auth: A,
req: reqwest::Client,
}

impl GitHub {
impl<A: ToString> GitHub<A> {
fn get(&self, url: impl AsRef<str>) -> RequestBuilder {
self.with_headers(self.req.get(Self::url(url)))
}

fn patch(&self, url: impl AsRef<str>) -> RequestBuilder {
self.with_headers(self.req.patch(Self::url(url)))
}

fn post(&self, url: impl AsRef<str>) -> RequestBuilder {
self.with_headers(self.req.post(Self::url(url)))
}

fn with_headers(&self, req: RequestBuilder) -> RequestBuilder {
req.bearer_auth(self.auth.to_string())
.header("Accept", "application/vnd.github+json")
.header("X-GitHub-Api-Version", "2022-11-28")
.header("User-Agent", "Typst package check")
}

fn url<S: AsRef<str>>(path: S) -> String {
format!("https://api.github.com/{}", path.as_ref())
}
}

pub trait GitHubAuth {
async fn auth_installation(
self,
installation: &impl AsInstallation,
) -> ApiResult<GitHub<AuthInstallation>>;
}

impl GitHubAuth for GitHub<AuthJwt> {
#[tracing::instrument(skip_all)]
pub async fn auth_installation(&mut self, payload: &HookPayload) -> ApiResult<()> {
let installation = &payload.installation().id;
async fn auth_installation(
self,
installation: &impl AsInstallation,
) -> ApiResult<GitHub<AuthInstallation>> {
let installation_id = installation.id();
let installation_token: InstallationToken = self
.post(format!("app/installations/{installation}/access_tokens"))
.post(format!("app/installations/{installation_id}/access_tokens"))
.send()
.await?
.parse_json()
.await?;
self.jwt = installation_token.token;

Ok(())
Ok(GitHub {
req: self.req,
auth: AuthInstallation(installation_token.token),
})
}
}

impl GitHubAuth for GitHub<AuthInstallation> {
async fn auth_installation(
self,
_installation: &impl AsInstallation,
) -> ApiResult<GitHub<AuthInstallation>> {
Ok(self)
}
}

impl GitHub<AuthInstallation> {
#[tracing::instrument(skip(self))]
pub async fn create_check_run(
&self,
Expand Down Expand Up @@ -137,29 +202,6 @@ impl GitHub {
debug!("GitHub said: {}", res);
Ok(())
}

fn get(&self, url: impl AsRef<str>) -> RequestBuilder {
self.with_headers(self.req.get(Self::url(url)))
}

fn patch(&self, url: impl AsRef<str>) -> RequestBuilder {
self.with_headers(self.req.patch(Self::url(url)))
}

fn post(&self, url: impl AsRef<str>) -> RequestBuilder {
self.with_headers(self.req.post(Self::url(url)))
}

fn with_headers(&self, req: RequestBuilder) -> RequestBuilder {
req.bearer_auth(&self.jwt)
.header("Accept", "application/vnd.github+json")
.header("X-GitHub-Api-Version", "2022-11-28")
.header("User-Agent", "Typst package check")
}

fn url<S: AsRef<str>>(path: S) -> String {
format!("https://api.github.com/{}", path.as_ref())
}
}

#[async_trait::async_trait]
Expand All @@ -182,7 +224,7 @@ impl FromRequestParts<AppState> for GitHub {
};

Ok(Self {
jwt: token,
auth: AuthJwt(token),
req: reqwest::Client::new(),
})
}
Expand Down Expand Up @@ -248,6 +290,16 @@ pub struct Installation {
pub id: u64,
}

pub trait AsInstallation {
fn id(&self) -> u64;
}

impl AsInstallation for Installation {
fn id(&self) -> u64 {
self.id
}
}

#[derive(Deserialize)]
struct InstallationToken {
token: String,
Expand Down
8 changes: 7 additions & 1 deletion src/github/api/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::github::AppState;

use super::{
check::{CheckRun, CheckRunAction, CheckSuite, CheckSuiteAction},
Installation, Repository,
AsInstallation, Installation, Repository,
};

#[derive(Debug)]
Expand All @@ -28,6 +28,12 @@ impl HookPayload {
}
}

impl AsInstallation for HookPayload {
fn id(&self) -> u64 {
self.installation().id
}
}

/// Request extractor that reads and check a GitHub hook payload.
#[async_trait::async_trait]
impl FromRequest<AppState> for HookPayload {
Expand Down
8 changes: 4 additions & 4 deletions src/github/api/pr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

use super::{ApiError, GitHub, JsonExt, OwnerId, RepoId};
use super::{ApiError, AuthInstallation, GitHub, JsonExt, OwnerId, RepoId};

#[derive(Clone, Debug, Deserialize)]
pub struct MinimalPullRequest {
Expand All @@ -10,7 +10,7 @@ pub struct MinimalPullRequest {
impl MinimalPullRequest {
pub async fn get_full(
&self,
api: &GitHub,
api: &GitHub<AuthInstallation>,
owner: OwnerId,
repo: RepoId,
) -> Result<PullRequest, ApiError> {
Expand Down Expand Up @@ -44,7 +44,7 @@ pub enum AnyPullRequest {
impl AnyPullRequest {
pub async fn get_full(
self,
api: &GitHub,
api: &GitHub<AuthInstallation>,
owner: OwnerId,
repo: RepoId,
) -> Result<PullRequest, ApiError> {
Expand All @@ -65,7 +65,7 @@ pub struct PullRequestUpdate {
pub title: String,
}

impl GitHub {
impl GitHub<AuthInstallation> {
pub async fn update_pull_request(
&self,
owner: OwnerId,
Expand Down

0 comments on commit 3bc8b56

Please sign in to comment.