mirror of
https://github.com/henrygd/beszel.git
synced 2025-12-17 02:36:17 +01:00
- Add version exchange between hub and agent. - Introduce ConnectionManager for managing WebSocket and SSH connections. - Implement fingerprint generation and storage in agent. - Create expiry map package to store universal tokens. - Update config.yml configuration to include tokens. - Enhance system management with new methods for handling system states and alerts. - Update front-end components to support token / fingerprint management features. - Introduce utility functions for token generation and hub URL retrieval. Co-authored-by: nhas <jordanatararimu@gmail.com>
248 lines
7.3 KiB
Go
248 lines
7.3 KiB
Go
package hub
|
|
|
|
import (
|
|
"beszel/internal/common"
|
|
"beszel/internal/hub/expirymap"
|
|
"beszel/internal/hub/ws"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/blang/semver"
|
|
"github.com/lxzan/gws"
|
|
"github.com/pocketbase/dbx"
|
|
"github.com/pocketbase/pocketbase/core"
|
|
)
|
|
|
|
// tokenMap maps tokens to user IDs for universal tokens
|
|
var tokenMap *expirymap.ExpiryMap[string]
|
|
|
|
type agentConnectRequest struct {
|
|
token string
|
|
agentSemVer semver.Version
|
|
// for universal token
|
|
isUniversalToken bool
|
|
userId string
|
|
remoteAddr string
|
|
}
|
|
|
|
// validateAgentHeaders validates the required headers from agent connection requests.
|
|
func (h *Hub) validateAgentHeaders(headers http.Header) (string, string, error) {
|
|
token := headers.Get("X-Token")
|
|
agentVersion := headers.Get("X-Beszel")
|
|
|
|
if agentVersion == "" || token == "" || len(token) > 512 {
|
|
return "", "", errors.New("")
|
|
}
|
|
return token, agentVersion, nil
|
|
}
|
|
|
|
// getFingerprintRecord retrieves fingerprint data from the database by token.
|
|
func (h *Hub) getFingerprintRecord(token string, recordData *ws.FingerprintRecord) error {
|
|
err := h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}").
|
|
Bind(dbx.Params{
|
|
"token": token,
|
|
}).
|
|
One(recordData)
|
|
return err
|
|
}
|
|
|
|
// sendResponseError sends an HTTP error response with the given status code and message.
|
|
func sendResponseError(res http.ResponseWriter, code int, message string) error {
|
|
res.WriteHeader(code)
|
|
if message != "" {
|
|
res.Write([]byte(message))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleAgentConnect handles the incoming connection request from the agent.
|
|
func (h *Hub) handleAgentConnect(e *core.RequestEvent) error {
|
|
if err := h.agentConnect(e.Request, e.Response); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// agentConnect handles agent connection requests, validating credentials and upgrading to WebSocket.
|
|
func (h *Hub) agentConnect(req *http.Request, res http.ResponseWriter) (err error) {
|
|
var agentConnectRequest agentConnectRequest
|
|
var agentVersion string
|
|
// check if user agent and token are valid
|
|
agentConnectRequest.token, agentVersion, err = h.validateAgentHeaders(req.Header)
|
|
if err != nil {
|
|
return sendResponseError(res, http.StatusUnauthorized, "")
|
|
}
|
|
|
|
// Pull fingerprint from database matching token
|
|
var fpRecord ws.FingerprintRecord
|
|
err = h.getFingerprintRecord(agentConnectRequest.token, &fpRecord)
|
|
|
|
// if no existing record, check if token is a universal token
|
|
if err != nil {
|
|
if err = checkUniversalToken(&agentConnectRequest); err == nil {
|
|
// if this is a universal token, set the remote address and new record token
|
|
agentConnectRequest.remoteAddr = getRealIP(req)
|
|
fpRecord.Token = agentConnectRequest.token
|
|
}
|
|
}
|
|
|
|
// If no matching token, return unauthorized
|
|
if err != nil {
|
|
return sendResponseError(res, http.StatusUnauthorized, "Invalid token")
|
|
}
|
|
|
|
// Validate agent version
|
|
agentConnectRequest.agentSemVer, err = semver.Parse(agentVersion)
|
|
if err != nil {
|
|
return sendResponseError(res, http.StatusUnauthorized, "Invalid agent version")
|
|
}
|
|
|
|
// Upgrade connection to WebSocket
|
|
conn, err := ws.GetUpgrader().Upgrade(res, req)
|
|
if err != nil {
|
|
return sendResponseError(res, http.StatusInternalServerError, "WebSocket upgrade failed")
|
|
}
|
|
|
|
go h.verifyWsConn(conn, agentConnectRequest, fpRecord)
|
|
|
|
return nil
|
|
}
|
|
|
|
// verifyWsConn verifies the WebSocket connection using agent's fingerprint and SSH key signature.
|
|
func (h *Hub) verifyWsConn(conn *gws.Conn, acr agentConnectRequest, fpRecord ws.FingerprintRecord) (err error) {
|
|
wsConn := ws.NewWsConnection(conn)
|
|
// must be set before the read loop
|
|
conn.Session().Store("wsConn", wsConn)
|
|
|
|
// make sure connection is closed if there is an error
|
|
defer func() {
|
|
if err != nil {
|
|
wsConn.Close()
|
|
h.Logger().Error("WebSocket error", "error", err, "system", fpRecord.SystemId)
|
|
}
|
|
}()
|
|
|
|
go conn.ReadLoop()
|
|
|
|
signer, err := h.GetSSHKey("")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
agentFingerprint, err := wsConn.GetFingerprint(acr.token, signer, acr.isUniversalToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create system if using universal token
|
|
if acr.isUniversalToken {
|
|
if acr.userId == "" {
|
|
return errors.New("token user not found")
|
|
}
|
|
fpRecord.SystemId, err = h.createSystemFromAgentData(&acr, agentFingerprint)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create system from universal token: %w", err)
|
|
}
|
|
}
|
|
|
|
switch {
|
|
// If no current fingerprint, update with new fingerprint (first time connecting)
|
|
case fpRecord.Fingerprint == "":
|
|
if err := h.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil {
|
|
return err
|
|
}
|
|
// Abort if fingerprint exists but doesn't match (different machine)
|
|
case fpRecord.Fingerprint != agentFingerprint.Fingerprint:
|
|
return errors.New("fingerprint mismatch")
|
|
}
|
|
|
|
return h.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn)
|
|
}
|
|
|
|
// createSystemFromAgentData creates a new system record using data from the agent
|
|
func (h *Hub) createSystemFromAgentData(acr *agentConnectRequest, agentFingerprint common.FingerprintResponse) (recordId string, err error) {
|
|
systemsCollection, err := h.FindCollectionByNameOrId("systems")
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to find systems collection: %w", err)
|
|
}
|
|
// separate port from address
|
|
if agentFingerprint.Hostname == "" {
|
|
agentFingerprint.Hostname = acr.remoteAddr
|
|
}
|
|
if agentFingerprint.Port == "" {
|
|
agentFingerprint.Port = "45876"
|
|
}
|
|
// create new record
|
|
systemRecord := core.NewRecord(systemsCollection)
|
|
systemRecord.Set("name", agentFingerprint.Hostname)
|
|
systemRecord.Set("host", acr.remoteAddr)
|
|
systemRecord.Set("port", agentFingerprint.Port)
|
|
systemRecord.Set("users", []string{acr.userId})
|
|
|
|
return systemRecord.Id, h.Save(systemRecord)
|
|
}
|
|
|
|
// SetFingerprint updates the fingerprint for a given record ID.
|
|
func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) (err error) {
|
|
// // can't use raw query here because it doesn't trigger SSE
|
|
var record *core.Record
|
|
switch fpRecord.Id {
|
|
case "":
|
|
// create new record for universal token
|
|
collection, _ := h.FindCachedCollectionByNameOrId("fingerprints")
|
|
record = core.NewRecord(collection)
|
|
record.Set("system", fpRecord.SystemId)
|
|
default:
|
|
record, err = h.FindRecordById("fingerprints", fpRecord.Id)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
record.Set("token", fpRecord.Token)
|
|
record.Set("fingerprint", fingerprint)
|
|
return h.SaveNoValidate(record)
|
|
}
|
|
|
|
func getTokenMap() *expirymap.ExpiryMap[string] {
|
|
if tokenMap == nil {
|
|
tokenMap = expirymap.New[string](time.Hour)
|
|
}
|
|
return tokenMap
|
|
}
|
|
|
|
func checkUniversalToken(acr *agentConnectRequest) (err error) {
|
|
if tokenMap == nil {
|
|
tokenMap = expirymap.New[string](time.Hour)
|
|
}
|
|
acr.userId, acr.isUniversalToken = tokenMap.GetOk(acr.token)
|
|
if !acr.isUniversalToken {
|
|
return errors.New("invalid token")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getRealIP attempts to extract the real IP address from the request headers.
|
|
func getRealIP(r *http.Request) string {
|
|
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
|
return ip
|
|
}
|
|
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
|
// X-Forwarded-For can contain a comma-separated list: "client_ip, proxy1, proxy2"
|
|
// Take the first one
|
|
ips := strings.Split(ip, ",")
|
|
if len(ips) > 0 {
|
|
return strings.TrimSpace(ips[0])
|
|
}
|
|
}
|
|
// Fallback to RemoteAddr
|
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
return r.RemoteAddr
|
|
}
|
|
return ip
|
|
}
|