# Replay Protection
## In-Memory Replay Protection
package main
import (
"context"
"errors"
"log"
"net/http"
"sync"
"time"
"github.com/SecureWebhookToken/swt"
)
// MemoryReplayChecker implements replay protection using in-memory storage.
//
// Each accepted JTI is stored together with the time it was first seen, and a
// background goroutine periodically evicts entries older than ttl. The state
// lives in this process only: it is lost on restart and is not shared between
// multiple server instances.
type MemoryReplayChecker struct {
	mu        sync.RWMutex         // guards seen
	seen      map[string]time.Time // JTI -> time it was first recorded
	ttl       time.Duration        // how long a JTI is remembered before eviction
	ticker    *time.Ticker         // paces the cleanup goroutine
	done      chan struct{}        // closed by Close to stop the cleanup goroutine
	closeOnce sync.Once            // lets Close be called more than once
}

// NewMemoryReplayChecker creates a new in-memory replay checker and starts
// its cleanup goroutine, which runs every ttl/2 until Close is called.
//
// A ttl so small that ttl/2 is not positive (zero, negative, or 1ns) does not
// panic: the cleanup interval falls back to one second while ttl itself is
// used unchanged for the expiry comparison.
func NewMemoryReplayChecker(ttl time.Duration) *MemoryReplayChecker {
	// time.NewTicker panics on a non-positive interval, so clamp it.
	const fallbackInterval = time.Second
	interval := ttl / 2
	if interval <= 0 {
		interval = fallbackInterval
	}

	checker := &MemoryReplayChecker{
		seen:   make(map[string]time.Time),
		ttl:    ttl,
		ticker: time.NewTicker(interval),
		done:   make(chan struct{}),
	}

	// Start cleanup goroutine
	go checker.cleanup()

	return checker
}

// CheckAndRecord implements the swt.ReplayChecker interface.
//
// It returns an error if jti has already been recorded and not yet evicted;
// otherwise it records jti with the current time and returns nil. The check
// and the insert happen under one lock, so concurrent calls with the same jti
// cannot both succeed. ctx is accepted for the interface and is not consulted.
func (m *MemoryReplayChecker) CheckAndRecord(ctx context.Context, jti string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Check if token has been seen before
	if _, exists := m.seen[jti]; exists {
		return errors.New("replay attack detected: token already used")
	}

	// Record the token
	m.seen[jti] = time.Now()
	return nil
}

// cleanup removes expired JTIs on every ticker tick until done is closed.
func (m *MemoryReplayChecker) cleanup() {
	defer m.ticker.Stop()

	for {
		select {
		case now := <-m.ticker.C:
			m.removeExpired(now)
		case <-m.done:
			return
		}
	}
}

// removeExpired deletes every entry recorded more than ttl before now.
func (m *MemoryReplayChecker) removeExpired(now time.Time) {
	m.mu.Lock()
	defer m.mu.Unlock()

	for jti, recordedAt := range m.seen {
		if now.Sub(recordedAt) > m.ttl {
			delete(m.seen, jti)
		}
	}
}

// Close stops the cleanup goroutine. It is safe to call more than once;
// only the first call has an effect.
func (m *MemoryReplayChecker) Close() {
	m.closeOnce.Do(func() { close(m.done) })
}
// main wires the in-memory replay checker into an SWT webhook handler and
// serves it on :8080 at /webhook.
func main() {
	// Placeholder only: this literal is 28 bytes (224 bits), shorter than the
	// 256 bits its own name asks for. Replace it with a real secret.
	secretKey := []byte("your-secret-key-min-256-bits")

	// Create replay checker with 15 minute TTL
	replayChecker := NewMemoryReplayChecker(15 * time.Minute)

	// Configure handler with replay protection
	opts := &swt.HandlerOptions{
		MaxBodySize:   5 * 1024 * 1024, // 5 MiB
		ReplayChecker: replayChecker,
	}
	webhookHandler := swt.NewHandlerFunc(secretKey, handleWebhook, opts)

	mux := http.NewServeMux()
	mux.HandleFunc("/webhook", webhookHandler)

	// Explicit timeouts: the zero-value server used by http.ListenAndServe
	// lets a slow or stalled client hold a connection open indefinitely.
	server := &http.Server{
		Addr:              ":8080",
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	log.Println("Webhook server with replay protection on :8080")
	err := server.ListenAndServe()

	// Close explicitly rather than with defer: log.Fatal exits the process
	// without running deferred calls.
	replayChecker.Close()
	if err != nil {
		log.Fatal(err)
	}
}
// handleWebhook is the callback handed to swt.NewHandlerFunc in main. It logs
// the event name and JTI of the received token and reports success; body is
// not used.
func handleWebhook(token *swt.SecureWebhookToken, body []byte) error {
	event, jti := token.Webhook().Event, token.ID()
	log.Printf("Processing webhook: %s (jti: %s)", event, jti)
	return nil
}
## Redis-based Replay Protection
package main
import (
"context"
"errors"
"fmt"
"log"
"net/http"
"time"
"github.com/SecureWebhookToken/swt"
"github.com/redis/go-redis/v9"
)
// RedisReplayChecker implements replay protection using Redis.
//
// It holds the client used to talk to Redis and the lifetime given to every
// JTI key it writes.
type RedisReplayChecker struct {
	client *redis.Client // Redis connection handle
	ttl    time.Duration // expiry applied to each recorded JTI key
}
// NewRedisReplayChecker creates a new Redis-based replay checker that talks
// to the server at addr (database 0, no password) and keeps each recorded JTI
// for ttl. No connection attempt is made here by this function itself.
func NewRedisReplayChecker(addr string, ttl time.Duration) *RedisReplayChecker {
	options := redis.Options{
		Addr:     addr,
		Password: "", // Set if needed
		DB:       0,
	}

	return &RedisReplayChecker{
		client: redis.NewClient(&options),
		ttl:    ttl,
	}
}
// CheckAndRecord implements the swt.ReplayChecker interface.
//
// It issues a single SETNX for the key "swt:jti:<jti>" with the configured
// TTL. A Redis failure is returned wrapped; a key that already exists is
// reported as a replay; otherwise nil is returned.
func (r *RedisReplayChecker) CheckAndRecord(ctx context.Context, jti string) error {
	key := "swt:jti:" + jti

	// SETNX only writes when the key is absent, so the check and the record
	// are one Redis command.
	stored, err := r.client.SetNX(ctx, key, "1", r.ttl).Result()
	switch {
	case err != nil:
		return fmt.Errorf("redis error: %w", err)
	case !stored:
		return errors.New("replay attack detected: token already used")
	}
	return nil
}
// Close closes the Redis connection and returns whatever error the client
// reports while doing so.
func (r *RedisReplayChecker) Close() error {
	err := r.client.Close()
	return err
}
// main wires the Redis replay checker into an SWT webhook handler and serves
// it on :8080 at /webhook. It exits early if Redis does not answer a ping.
func main() {
	// Placeholder only: this literal is 28 bytes (224 bits), shorter than the
	// 256 bits its own name asks for. Replace it with a real secret.
	secretKey := []byte("your-secret-key-min-256-bits")

	// Create Redis replay checker
	replayChecker := NewRedisReplayChecker("localhost:6379", 15*time.Minute)

	// closeChecker is called explicitly on every exit path: log.Fatal and
	// log.Fatalf exit the process without running deferred calls.
	closeChecker := func() {
		if err := replayChecker.Close(); err != nil {
			log.Printf("Failed to close Redis client: %v", err)
		}
	}

	// Test Redis connection
	ctx := context.Background()
	if err := replayChecker.client.Ping(ctx).Err(); err != nil {
		closeChecker()
		log.Fatalf("Redis connection failed: %v", err)
	}

	opts := &swt.HandlerOptions{
		MaxBodySize:   5 * 1024 * 1024, // 5 MiB
		ReplayChecker: replayChecker,
	}
	webhookHandler := swt.NewHandlerFunc(secretKey, handleWebhook, opts)

	mux := http.NewServeMux()
	mux.HandleFunc("/webhook", webhookHandler)

	// Explicit timeouts: the zero-value server used by http.ListenAndServe
	// lets a slow or stalled client hold a connection open indefinitely.
	server := &http.Server{
		Addr:              ":8080",
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	log.Println("Webhook server with Redis replay protection on :8080")
	err := server.ListenAndServe()
	closeChecker()
	if err != nil {
		log.Fatal(err)
	}
}
// handleWebhook is the callback handed to swt.NewHandlerFunc in main. It logs
// the event name and JTI of the received token and reports success; body is
// not used.
func handleWebhook(token *swt.SecureWebhookToken, body []byte) error {
	event, jti := token.Webhook().Event, token.ID()
	log.Printf("Processing webhook: %s (jti: %s)", event, jti)
	return nil
}
## Database-based Replay Protection (PostgreSQL)
package main
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/SecureWebhookToken/swt"
	_ "github.com/lib/pq"
)
// PostgresReplayChecker implements replay protection using PostgreSQL.
//
// Every accepted JTI is inserted into the webhook_tokens table, whose primary
// key makes a second insert of the same JTI fail. A background goroutine
// deletes expired rows every five minutes until Close is called.
type PostgresReplayChecker struct {
	db   *sql.DB
	ttl  time.Duration      // lifetime given to each recorded JTI
	stop context.CancelFunc // cancels the cleanup goroutine; safe to call repeatedly
}

// NewPostgresReplayChecker creates a new PostgreSQL-based replay checker.
//
// It opens a pool for connStr, creates the webhook_tokens table and its
// expiry index if they do not exist, and starts the cleanup goroutine. On any
// failure it returns a nil checker and an error that wraps the underlying
// one, and releases the pool it opened.
func NewPostgresReplayChecker(connStr string, ttl time.Duration) (*PostgresReplayChecker, error) {
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return nil, fmt.Errorf("opening postgres connection: %w", err)
	}

	// Create table if it doesn't exist
	_, err = db.Exec(`
		CREATE TABLE IF NOT EXISTS webhook_tokens (
			jti VARCHAR(255) PRIMARY KEY,
			created_at TIMESTAMP NOT NULL DEFAULT NOW(),
			expires_at TIMESTAMP NOT NULL
		);
		CREATE INDEX IF NOT EXISTS idx_expires_at ON webhook_tokens(expires_at);
	`)
	if err != nil {
		// The caller never sees db on this path, so release it here. The
		// close error is dropped because the schema error is the useful one.
		_ = db.Close()
		return nil, fmt.Errorf("creating webhook_tokens schema: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	checker := &PostgresReplayChecker{
		db:   db,
		ttl:  ttl,
		stop: cancel,
	}

	// Start cleanup goroutine
	go checker.cleanupExpired(ctx)

	return checker, nil
}

// CheckAndRecord implements the swt.ReplayChecker interface.
//
// It inserts jti with an expiry of ttl from now. A duplicate-key failure is
// reported as a replay; any other failure is returned wrapped.
func (p *PostgresReplayChecker) CheckAndRecord(ctx context.Context, jti string) error {
	// The expiry is computed by the database, not from time.Now(): expires_at
	// is a TIMESTAMP without time zone and cleanup compares it with the
	// database's NOW(), so an application-side timestamp would be off by the
	// zone difference between the two hosts.
	_, err := p.db.ExecContext(ctx, `
		INSERT INTO webhook_tokens (jti, expires_at)
		VALUES ($1, NOW() + $2::double precision * INTERVAL '1 second')
	`, jti, p.ttl.Seconds())
	if err != nil {
		// Check if it's a duplicate key error
		if isDuplicateKeyError(err) {
			return errors.New("replay attack detected: token already used")
		}
		return fmt.Errorf("database error: %w", err)
	}

	return nil
}

// cleanupExpired removes expired tokens every five minutes until ctx is
// cancelled. Failures are logged and the loop continues.
func (p *PostgresReplayChecker) cleanupExpired(ctx context.Context) {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			_, err := p.db.ExecContext(ctx, `
				DELETE FROM webhook_tokens
				WHERE expires_at < NOW()
			`)
			// A failure caused by shutdown itself is not worth logging.
			if err != nil && ctx.Err() == nil {
				log.Printf("Failed to cleanup expired tokens: %v", err)
			}
		}
	}
}

// Close stops the cleanup goroutine and closes the database connection.
// It may be called more than once.
func (p *PostgresReplayChecker) Close() error {
	// stop is nil only for a checker that was not built by the constructor.
	if p.stop != nil {
		p.stop()
	}
	return p.db.Close()
}
// isDuplicateKeyError reports whether err is a PostgreSQL unique-constraint
// violation (SQLSTATE 23505). It returns false for a nil error.
//
// If any error in err's chain exposes a SQLState() string method, that code
// is authoritative. Otherwise the function falls back to matching the error
// text, which only works for the English message with the default constraint
// name or for text that happens to include the code.
func isDuplicateKeyError(err error) bool {
	// PostgreSQL duplicate key error code: 23505
	const uniqueViolation = "23505"

	if err == nil {
		return false
	}

	// An anonymous interface keeps this independent of the driver package.
	var stateErr interface{ SQLState() string }
	if errors.As(err, &stateErr) {
		return stateErr.SQLState() == uniqueViolation
	}

	msg := err.Error()
	return msg == "pq: duplicate key value violates unique constraint \"webhook_tokens_pkey\"" ||
		strings.Contains(msg, uniqueViolation)
}
// main wires the PostgreSQL replay checker into an SWT webhook handler and
// serves it on :8080 at /webhook.
func main() {
	// Placeholder only: this literal is 28 bytes (224 bits), shorter than the
	// 256 bits its own name asks for. Replace it with a real secret.
	secretKey := []byte("your-secret-key-min-256-bits")

	// Database connection string (example credentials, TLS disabled).
	connStr := "postgres://user:password@localhost/webhooks?sslmode=disable"

	// Create PostgreSQL replay checker
	replayChecker, err := NewPostgresReplayChecker(connStr, 15*time.Minute)
	if err != nil {
		log.Fatalf("Failed to create replay checker: %v", err)
	}

	opts := &swt.HandlerOptions{
		MaxBodySize:   5 * 1024 * 1024, // 5 MiB
		ReplayChecker: replayChecker,
	}
	webhookHandler := swt.NewHandlerFunc(secretKey, handleWebhook, opts)

	mux := http.NewServeMux()
	mux.HandleFunc("/webhook", webhookHandler)

	// Explicit timeouts: the zero-value server used by http.ListenAndServe
	// lets a slow or stalled client hold a connection open indefinitely.
	server := &http.Server{
		Addr:              ":8080",
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	log.Println("Webhook server with PostgreSQL replay protection on :8080")
	serveErr := server.ListenAndServe()

	// Close explicitly rather than with defer: log.Fatal exits the process
	// without running deferred calls.
	if err := replayChecker.Close(); err != nil {
		log.Printf("Failed to close replay checker: %v", err)
	}
	if serveErr != nil {
		log.Fatal(serveErr)
	}
}
// handleWebhook is the callback handed to swt.NewHandlerFunc in main. It logs
// the event name and JTI of the received token and reports success; body is
// not used.
func handleWebhook(token *swt.SecureWebhookToken, body []byte) error {
	event, jti := token.Webhook().Event, token.ID()
	log.Printf("Processing webhook: %s (jti: %s)", event, jti)
	return nil
}
## Testing Replay Protection
package main
import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/SecureWebhookToken/swt"
)
// TestReplayProtection sends the same signed webhook request twice to a test
// server protected by a MemoryReplayChecker. The first delivery must return
// 200 and the second, identical one must return 401.
func TestReplayProtection(t *testing.T) {
	secretKey := []byte("test-secret-key-for-testing-only")

	// Create replay checker
	replayChecker := NewMemoryReplayChecker(5 * time.Minute)
	defer replayChecker.Close()

	// Create handler with replay protection
	opts := &swt.HandlerOptions{
		ReplayChecker: replayChecker,
	}
	webhookHandler := swt.NewHandlerFunc(secretKey, func(token *swt.SecureWebhookToken, body []byte) error {
		return nil
	}, opts)

	// Create test server
	server := httptest.NewServer(webhookHandler)
	defer server.Close()

	// Create webhook request
	req := &swt.Request{
		URL:    server.URL,
		Issuer: "test-service",
		Event:  "test.event",
		Data:   []byte(`{"test":"data"}`),
	}

	// Build HTTP request.
	// NOTE(review): assumes Build returns *http.Request, as the original
	// passed its result straight to client.Do.
	httpReq, err := req.Build(secretKey)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}

	// A request body is a one-shot stream, so capture the signed payload once
	// and replay it from memory for each delivery.
	var payload bytes.Buffer
	if httpReq.Body != nil {
		if _, err := payload.ReadFrom(httpReq.Body); err != nil {
			t.Fatalf("Failed to read request body: %v", err)
		}
		httpReq.Body.Close()
	}

	ctx := context.Background()
	client := &http.Client{Timeout: 5 * time.Second}

	// send delivers a fresh copy of the built request and returns the status.
	send := func() (int, error) {
		r, err := http.NewRequestWithContext(ctx, httpReq.Method, httpReq.URL.String(), bytes.NewReader(payload.Bytes()))
		if err != nil {
			return 0, fmt.Errorf("creating request: %w", err)
		}
		// Clone so the two deliveries do not share one header map.
		r.Header = httpReq.Header.Clone()

		resp, err := client.Do(r)
		if err != nil {
			return 0, fmt.Errorf("sending request: %w", err)
		}
		defer resp.Body.Close()
		return resp.StatusCode, nil
	}

	// First request should succeed
	status, err := send()
	if err != nil {
		t.Fatalf("First request failed: %v", err)
	}
	if status != http.StatusOK {
		t.Errorf("Expected status 200, got %d", status)
	}

	// Second request with same token should fail (replay attack)
	status, err = send()
	if err != nil {
		t.Fatalf("Second request failed: %v", err)
	}
	if status != http.StatusUnauthorized {
		t.Errorf("Expected status 401 for replay attack, got %d", status)
	}

	t.Log("Replay protection test passed!")
}