Skip to content

Commit

Permalink
treewide: replace gorilla/mux with http.ServeMux
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Aug 23, 2024
1 parent 66c4178 commit dd17c7b
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 185 deletions.
17 changes: 8 additions & 9 deletions appservice/appservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"syscall"
"time"

"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
Expand All @@ -43,7 +42,7 @@ func Create() *AppService {
intents: make(map[id.UserID]*IntentAPI),
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
StateStore: mautrix.NewMemoryStateStore().(StateStore),
Router: mux.NewRouter(),
Router: http.NewServeMux(),
UserAgent: mautrix.DefaultUserAgent,
txnIDC: NewTransactionIDCache(128),
Live: true,
Expand All @@ -61,12 +60,12 @@ func Create() *AppService {
DefaultHTTPRetries: 4,
}

as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)

return as
}
Expand Down Expand Up @@ -160,7 +159,7 @@ type AppService struct {
QueryHandler QueryHandler
StateStore StateStore

Router *mux.Router
Router *http.ServeMux
UserAgent string
server *http.Server
HTTPClient *http.Client
Expand Down
12 changes: 3 additions & 9 deletions appservice/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"syscall"
"time"

"github.com/gorilla/mux"
"github.com/rs/zerolog"

"maunium.net/go/mautrix"
Expand Down Expand Up @@ -106,8 +105,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
return
}

vars := mux.Vars(r)
txnID := vars["txnID"]
txnID := r.PathValue("txnID")
if len(txnID) == 0 {
Error{
ErrorCode: ErrNoTransactionID,
Expand Down Expand Up @@ -263,9 +261,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
return
}

vars := mux.Vars(r)
roomAlias := vars["roomAlias"]
ok := as.QueryHandler.QueryAlias(roomAlias)
ok := as.QueryHandler.QueryAlias(r.PathValue("roomAlias"))
if ok {
WriteBlankOK(w)
} else {
Expand All @@ -282,9 +278,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
return
}

vars := mux.Vars(r)
userID := id.UserID(vars["userID"])
ok := as.QueryHandler.QueryUser(userID)
ok := as.QueryHandler.QueryUser(id.UserID(r.PathValue("userID")))
if ok {
WriteBlankOK(w)
} else {
Expand Down
5 changes: 3 additions & 2 deletions bridgev2/matrix/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"regexp"
Expand All @@ -21,7 +22,6 @@ import (
"time"
"unsafe"

"github.com/gorilla/mux"
_ "github.com/lib/pq"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
Expand Down Expand Up @@ -216,7 +216,8 @@ func (br *Connector) GetPublicAddress() string {
return br.Config.AppService.PublicAddress
}

func (br *Connector) GetRouter() *mux.Router {
// TODO switch to http.ServeMux
func (br *Connector) GetRouter() *http.ServeMux {
if br.GetPublicAddress() != "" {
return br.AS.Router
}
Expand Down
89 changes: 47 additions & 42 deletions bridgev2/matrix/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import (
"sync"
"time"

"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/requestlog"

Expand All @@ -37,7 +37,7 @@ type matrixAuthCacheEntry struct {
}

type ProvisioningAPI struct {
Router *mux.Router
Router *http.ServeMux

br *Connector
log zerolog.Logger
Expand Down Expand Up @@ -72,12 +72,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}

func (prov *ProvisioningAPI) GetRouter() *mux.Router {
func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
return prov.Router
}

type IProvisioningAPI interface {
GetRouter() *mux.Router
GetRouter() *http.ServeMux
GetUser(r *http.Request) *bridgev2.User
}

Expand All @@ -96,44 +96,38 @@ func (prov *ProvisioningAPI) Init() {
tp.Dialer.Timeout = 10 * time.Second
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
prov.Router.Use(hlog.NewHandler(prov.log))
prov.Router.Use(corsMiddleware)
prov.Router.Use(requestlog.AccessLogger(false))
prov.Router.Use(prov.AuthMiddleware)
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)

provRouter := http.NewServeMux()

provRouter.HandleFunc("GET /v3/whoami", prov.GetWhoami)
provRouter.HandleFunc("GET /v3/whoami/flows", prov.GetLoginFlows)

provRouter.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
provRouter.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLogin)
provRouter.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
provRouter.HandleFunc("GET /v3/logins", prov.GetLogins)
provRouter.HandleFunc("GET /v3/contacts", prov.GetContactList)
provRouter.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
provRouter.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
provRouter.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)

var provHandler http.Handler = prov.Router
provHandler = prov.AuthMiddleware(provHandler)
provHandler = requestlog.AccessLogger(false)(provHandler)
provHandler = exhttp.CORSMiddleware(provHandler)
provHandler = hlog.NewHandler(prov.log)(provHandler)
provHandler = http.StripPrefix(prov.br.Config.Provisioning.Prefix, provHandler)
prov.br.AS.Router.Handle(prov.br.Config.Provisioning.Prefix, provHandler)

if prov.br.Config.Provisioning.DebugEndpoints {
prov.log.Debug().Msg("Enabling debug API at /debug")
r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
r.Use(prov.AuthMiddleware)
r.PathPrefix("/pprof").Handler(http.DefaultServeMux)
debugRouter := http.NewServeMux()
// TODO do we need to strip prefix here?
debugRouter.Handle("/debug/pprof", http.StripPrefix("/debug/pprof", http.DefaultServeMux))
prov.br.AS.Router.Handle("/debug", prov.AuthMiddleware(debugRouter))
}
}

func corsMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
handler.ServeHTTP(w, r)
})
}

func jsonResponse(w http.ResponseWriter, status int, response any) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)
Expand Down Expand Up @@ -221,7 +215,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
// TODO handle user being nil?

ctx := context.WithValue(r.Context(), provisioningUserKey, user)
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
if loginID := r.PathValue("loginProcessID"); loginID != "" {
prov.loginsLock.RLock()
login, ok := prov.logins[loginID]
prov.loginsLock.RUnlock()
Expand All @@ -236,7 +230,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
login.Lock.Lock()
// This will only unlock after the handler runs
defer login.Lock.Unlock()
stepID := mux.Vars(r)["stepID"]
stepID := r.PathValue("stepID")
if login.NextStep.StepID != stepID {
zerolog.Ctx(r.Context()).Warn().
Str("request_step_id", stepID).
Expand All @@ -248,7 +242,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
})
return
}
stepType := mux.Vars(r)["stepType"]
stepType := r.PathValue("stepType")
if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
zerolog.Ctx(r.Context()).Warn().
Str("request_step_type", stepType).
Expand Down Expand Up @@ -352,7 +346,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
login, err := prov.net.CreateLogin(
r.Context(),
prov.GetUser(r),
mux.Vars(r)["flowID"],
r.PathValue("flowID"),
)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
Expand Down Expand Up @@ -391,6 +385,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
}, bridgev2.DeleteOpts{LogoutRemote: true})
}

func (prov *ProvisioningAPI) PostLogin(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("stepType") {
case "user_input", "cookies":
prov.PostLoginSubmitInput(w, r)
case "display_and_wait":
prov.PostLoginWait(w, r)
default:
panic("Impossible state") // checked by the AuthMiddleware
}
}

func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) {
var params map[string]string
err := json.NewDecoder(r.Body).Decode(&params)
Expand Down Expand Up @@ -444,7 +449,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques

func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
user := prov.GetUser(r)
userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
if userLoginID == "all" {
for {
login := user.GetDefaultLogin()
Expand Down Expand Up @@ -548,7 +553,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
})
return
}
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
respondMaybeCustomError(w, err, "Internal error resolving identifier")
Expand Down
11 changes: 4 additions & 7 deletions bridgev2/matrix/publicmedia.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import (
"net/http"
"time"

"github.com/gorilla/mux"

"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/id"
)
Expand All @@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
return nil
}

Expand Down Expand Up @@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{
}

func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
contentURI := id.ContentURI{
Homeserver: vars["server"],
FileID: vars["mediaID"],
Homeserver: r.PathValue("server"),
FileID: r.PathValue("mediaID"),
}
if !contentURI.IsValid() {
http.Error(w, "invalid content URI", http.StatusBadRequest)
return
}
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
return
Expand Down
5 changes: 2 additions & 3 deletions bridgev2/matrixinterface.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ package bridgev2
import (
"context"
"io"
"net/http"
"time"

"github.com/gorilla/mux"

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/bridgev2/database"
Expand Down Expand Up @@ -56,7 +55,7 @@ type MatrixConnector interface {

type MatrixConnectorWithServer interface {
GetPublicAddress() string
GetRouter() *mux.Router
GetRouter() *http.ServeMux
}

type MatrixConnectorWithPublicMedia interface {
Expand Down
Loading

0 comments on commit dd17c7b

Please sign in to comment.