Skip to content

Commit

Permalink
Move ws challenge to router
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Jun 10, 2024
1 parent ab97da5 commit aaffa42
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 215 deletions.
2 changes: 2 additions & 0 deletions internal/router/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
var (
lock sync.RWMutex
cancel = make(chan bool)

Verifier = NewChallenger()
)

func Setup(errorChan chan<- error, iptables bool) (err error) {
Expand Down
155 changes: 155 additions & 0 deletions internal/router/session_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package router

import (
"crypto/subtle"
"fmt"
"log"
"net/http"
"sync"
"time"

"github.com/NHAS/wag/internal/data"
"github.com/NHAS/wag/internal/users"
"github.com/NHAS/wag/internal/utils"
"github.com/gorilla/websocket"
)

type wsConnWrapper struct {
*websocket.Conn
wait chan interface{}
}

func (ws *wsConnWrapper) Await() <-chan interface{} {
return ws.wait
}

func (ws *wsConnWrapper) Close() error {
close(ws.wait)
return ws.Conn.Close()
}

type Challenger struct {
sync.RWMutex
connections map[string]*wsConnWrapper

upgrader websocket.Upgrader
}

func NewChallenger() *Challenger {
r := &Challenger{
connections: make(map[string]*wsConnWrapper),
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
domain, err := data.GetDomain()
if err != nil {
log.Println("was unable to get the wag domain: ", err)
return false
}

valid := r.Header.Get("Origin") == domain
if !valid {
log.Printf("websocket origin does not equal expected value: %q != %q", r.Header.Get("Origin"), domain)
}

return valid
},
},
}

return r
}

func (c *Challenger) Challenge(address string) error {
c.RLock()
defer c.RUnlock()

conn, ok := c.connections[address]
if !ok {
return fmt.Errorf("no connection found for device: %s", address)
}

err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
if err != nil {
return err
}

err = conn.WriteJSON("challenge")
if err != nil {
return err
}

err = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if err != nil {
return err
}

msg := struct{ Challenge string }{}
err = conn.ReadJSON(&msg)
if err != nil {
return err
}

deviceDetails, err := data.GetDeviceByAddress(address)
if err != nil {
return fmt.Errorf("failed to get device address for ws challenge: %s", err)
}

if subtle.ConstantTimeCompare([]byte(deviceDetails.Challenge), []byte(msg.Challenge)) != 1 {
return fmt.Errorf("challenge does not match")
}

return nil
}

func (c *Challenger) WS(w http.ResponseWriter, r *http.Request) {
remoteAddress := utils.GetIPFromRequest(r)
user, err := users.GetUserFromAddress(remoteAddress)
if err != nil {
log.Println("unknown", remoteAddress, "Could not find user: ", err)
http.Error(w, "Server Error", http.StatusInternalServerError)
return
}

// Upgrade HTTP connection to WebSocket connection
_c, err := c.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(user.Username, remoteAddress, "failed to create websocket challenger:", err)
http.Error(w, "Server Error", http.StatusInternalServerError)
return
}

conn := &wsConnWrapper{Conn: _c, wait: make(chan interface{})}

defer func() {
if conn != nil {
conn.Close()
}

c.Lock()
delete(c.connections, remoteAddress.String())
c.Unlock()

}()

c.Lock()

if prev, ok := c.connections[remoteAddress.String()]; ok && prev != nil {
prev.Close()
}

c.connections[remoteAddress.String()] = conn
c.Unlock()

err = c.Challenge(remoteAddress.String())
if err != nil {
log.Printf("client did not complete ws challenge: %s", err)
return
}

select {
case <-cancel:
case <-conn.Await():
}
}
18 changes: 13 additions & 5 deletions internal/router/statemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,27 @@ func deviceChanges(_ string, current, previous data.Device, et data.EventType) e
return fmt.Errorf("cannot get lockout: %s", err)
}

if current.Endpoint.String() != previous.Endpoint.String() {

// Will take at most 4 seconds
err := Verifier.Challenge(current.Address)
if err != nil {
log.Printf("%s:%s failed to pass websockets challenge: %s", current.Username, current.Address, err)
err := Deauthenticate(current.Address)
if err != nil {
return fmt.Errorf("cannot deauthenticate device %s: %s", current.Address, err)
}
}
}

if current.Attempts > lockout || // If the number of authentication attempts on a device has exceeded the max
current.Endpoint.String() != previous.Endpoint.String() || // If the client ip has changed
current.Authorised.IsZero() { // If we've explicitly deauthorised a device

var reasons []string
if current.Attempts > lockout {
reasons = []string{fmt.Sprintf("exceeded lockout (%d)", current.Attempts)}
}

if current.Endpoint.String() != previous.Endpoint.String() {
reasons = append(reasons, "endpoint changed")
}

if current.Authorised.IsZero() {
reasons = append(reasons, "session terminated")
}
Expand Down
2 changes: 0 additions & 2 deletions internal/webserver/authenticators/authenticators.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ var (
types.Pam: new(Pam),
}
lck sync.RWMutex

ChallengesManager = NewChallenger()
)

func GetMethod(method string) (Authenticator, bool) {
Expand Down
Loading

0 comments on commit aaffa42

Please sign in to comment.