You've already forked hastebin-ansi
mirror of
https://github.com/armbian/hastebin-ansi.git
synced 2026-01-06 12:30:55 -08:00
server: add support for rate limiting
This commit is contained in:
committed by
M. Efe Çetin
parent
4766ac3370
commit
5331e14954
24
cmd/main.go
24
cmd/main.go
@@ -13,7 +13,8 @@ import (
|
||||
"github.com/armbian/ansi-hastebin/internal/keygenerator"
|
||||
"github.com/armbian/ansi-hastebin/internal/server"
|
||||
"github.com/armbian/ansi-hastebin/internal/storage"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func handleConfig(location string) (*config.Config, storage.Storage, keygenerator.KeyGenerator) {
|
||||
@@ -35,7 +36,7 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato
|
||||
case "s3":
|
||||
pasteStorage = storage.NewS3Storage(cfg.Storage.Host, cfg.Storage.Port, cfg.Storage.Username, cfg.Storage.Password, cfg.Storage.AWSRegion, cfg.Storage.Bucket)
|
||||
default:
|
||||
logrus.Fatalf("Unknown storage type: %s", cfg.Storage.Type)
|
||||
log.Fatal().Str("storage_type", cfg.Storage.Type).Msg("Unknown storage type")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -43,17 +44,17 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato
|
||||
for _, doc := range cfg.Documents {
|
||||
file, err := os.OpenFile(doc.Path, os.O_RDONLY, 0644)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("path", doc.Path).Fatal("Failed to open document")
|
||||
log.Fatal().Err(err).Str("path", doc.Path).Msg("Failed to open document")
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("path", doc.Path).Fatal("Failed to read document")
|
||||
log.Fatal().Err(err).Str("path", doc.Path).Msg("Failed to read document")
|
||||
}
|
||||
file.Close()
|
||||
|
||||
if err := pasteStorage.Set(doc.Key, string(content), false); err != nil {
|
||||
logrus.WithError(err).WithField("key", doc.Key).Fatal("Failed to set document")
|
||||
log.Fatal().Err(err).Str("key", doc.Key).Msg("Failed to set document")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,10 +66,21 @@ func handleConfig(location string) (*config.Config, storage.Storage, keygenerato
|
||||
case "phonetic":
|
||||
keyGenerator = keygenerator.NewPhoneticKeyGenerator()
|
||||
default:
|
||||
logrus.Fatalf("Unknown key generator: %s", cfg.KeyGenerator)
|
||||
log.Fatal().Str("key_generator", cfg.KeyGenerator).Msg("Unknown key generator")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// Adjust logger
|
||||
logLevel, err := zerolog.ParseLevel(cfg.Logging.Level)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Str("level", cfg.Logging.Level).Msg("Failed to parse log level")
|
||||
}
|
||||
log.Logger = log.Level(logLevel)
|
||||
|
||||
if cfg.Logging.Colorize {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout})
|
||||
}
|
||||
|
||||
return cfg, pasteStorage, keyGenerator
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ documents:
|
||||
- key: "about"
|
||||
path: "./about.md"
|
||||
|
||||
rate_limiting:
|
||||
enable: true
|
||||
limit: 500
|
||||
window: 15
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
type: "text"
|
||||
|
||||
@@ -5,14 +5,27 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Type string `yaml:"type"`
|
||||
Colorize bool `yaml:"colorize"`
|
||||
// Level is the logging level to use
|
||||
Level string `yaml:"level"`
|
||||
|
||||
// Colorize is a flag to enable colorized output
|
||||
Colorize bool `yaml:"colorize"`
|
||||
}
|
||||
|
||||
type RateLimitingConfig struct {
|
||||
// Enable is a flag to enable rate limiting
|
||||
Enable bool `yaml:"enable"`
|
||||
|
||||
// Limit is the maximum number of requests
|
||||
Limit int `yaml:"limit"`
|
||||
|
||||
// Window is the time window to limit requests
|
||||
Window int `yaml:"window"`
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
@@ -94,6 +107,9 @@ type Config struct {
|
||||
// Logging is the logging configuration
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
|
||||
// RateLimiting is the rate limiting configuration
|
||||
RateLimiting RateLimitingConfig `yaml:"rate_limiting"`
|
||||
|
||||
// Documents is the list of documents to load statically
|
||||
Documents []DocumentConfig `yaml:"documents"`
|
||||
}
|
||||
@@ -112,7 +128,6 @@ var DefaultConfig = &Config{
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Type: "text",
|
||||
},
|
||||
Documents: []DocumentConfig{
|
||||
{
|
||||
@@ -129,12 +144,12 @@ func NewConfig(configFile string) *Config {
|
||||
// Read the configuration file
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
logrus.WithError(err).Fatal("Failed to read configuration file")
|
||||
log.Fatal().Err(err).Msg("Failed to read configuration file")
|
||||
}
|
||||
|
||||
// Unmarshal the configuration file
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to unmarshal configuration file")
|
||||
log.Fatal().Err(err).Msg("Failed to unmarshal configuration file")
|
||||
}
|
||||
|
||||
// Override with environment variables
|
||||
@@ -145,7 +160,7 @@ func NewConfig(configFile string) *Config {
|
||||
if port := os.Getenv("PORT"); port != "" {
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse PORT environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse PORT environment variable")
|
||||
}
|
||||
cfg.Port = portInt
|
||||
}
|
||||
@@ -153,7 +168,7 @@ func NewConfig(configFile string) *Config {
|
||||
if keyLength := os.Getenv("KEY_LENGTH"); keyLength != "" {
|
||||
keyLengthInt, err := strconv.Atoi(keyLength)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse KEY_LENGTH environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse KEY_LENGTH environment variable")
|
||||
}
|
||||
cfg.KeyLength = keyLengthInt
|
||||
}
|
||||
@@ -161,7 +176,7 @@ func NewConfig(configFile string) *Config {
|
||||
if maxLength := os.Getenv("MAX_LENGTH"); maxLength != "" {
|
||||
maxLengthInt, err := strconv.Atoi(maxLength)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse MAX_LENGTH environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse MAX_LENGTH environment variable")
|
||||
}
|
||||
cfg.MaxLength = maxLengthInt
|
||||
}
|
||||
@@ -169,7 +184,7 @@ func NewConfig(configFile string) *Config {
|
||||
if staticMaxAge := os.Getenv("STATIC_MAX_AGE"); staticMaxAge != "" {
|
||||
staticMaxAgeInt, err := strconv.Atoi(staticMaxAge)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse STATIC_MAX_AGE environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse STATIC_MAX_AGE environment variable")
|
||||
}
|
||||
cfg.StaticMaxAge = staticMaxAgeInt
|
||||
}
|
||||
@@ -177,7 +192,7 @@ func NewConfig(configFile string) *Config {
|
||||
if recompressStaticAssets := os.Getenv("RECOMPRESS_STATIC_ASSETS"); recompressStaticAssets != "" {
|
||||
recompressStaticAssetsBool, err := strconv.ParseBool(recompressStaticAssets)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse RECOMPRESS_STATIC_ASSETS environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse RECOMPRESS_STATIC_ASSETS environment variable")
|
||||
}
|
||||
cfg.RecompressStaticAssets = recompressStaticAssetsBool
|
||||
}
|
||||
@@ -197,7 +212,7 @@ func NewConfig(configFile string) *Config {
|
||||
if storagePort := os.Getenv("STORAGE_PORT"); storagePort != "" {
|
||||
storagePortInt, err := strconv.Atoi(storagePort)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse STORAGE_PORT environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse STORAGE_PORT environment variable")
|
||||
}
|
||||
cfg.Storage.Port = storagePortInt
|
||||
}
|
||||
@@ -230,18 +245,41 @@ func NewConfig(configFile string) *Config {
|
||||
cfg.Logging.Level = loggingLevel
|
||||
}
|
||||
|
||||
if loggingType := os.Getenv("LOGGING_TYPE"); loggingType != "" {
|
||||
cfg.Logging.Type = loggingType
|
||||
}
|
||||
|
||||
if loggingColorize := os.Getenv("LOGGING_COLORIZE"); loggingColorize != "" {
|
||||
loggingColorizeBool, err := strconv.ParseBool(loggingColorize)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("Failed to parse LOGGING_COLORIZE environment variable")
|
||||
log.Fatal().Err(err).Msg("Failed to parse LOGGING_COLORIZE environment variable")
|
||||
}
|
||||
cfg.Logging.Colorize = loggingColorizeBool
|
||||
}
|
||||
|
||||
if rateLimitingEnable := os.Getenv("RATE_LIMITING_ENABLE"); rateLimitingEnable != "" {
|
||||
rateLimitingEnableBool, err := strconv.ParseBool(rateLimitingEnable)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_ENABLE environment variable")
|
||||
}
|
||||
|
||||
cfg.RateLimiting.Enable = rateLimitingEnableBool
|
||||
}
|
||||
|
||||
if rateLimitingLimit := os.Getenv("RATE_LIMITING_LIMIT"); rateLimitingLimit != "" {
|
||||
rateLimitingLimitInt, err := strconv.Atoi(rateLimitingLimit)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_LIMIT environment variable")
|
||||
}
|
||||
|
||||
cfg.RateLimiting.Limit = rateLimitingLimitInt
|
||||
}
|
||||
|
||||
if rateLimitingWindow := os.Getenv("RATE_LIMITING_WINDOW"); rateLimitingWindow != "" {
|
||||
rateLimitingWindowInt, err := strconv.Atoi(rateLimitingWindow)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to parse RATE_LIMITING_WINDOW environment variable")
|
||||
}
|
||||
|
||||
cfg.RateLimiting.Window = rateLimitingWindowInt
|
||||
}
|
||||
|
||||
// Walk environment variables for documents
|
||||
for _, env := range os.Environ() {
|
||||
if len(env) > 10 && env[:10] == "DOCUMENTS_" {
|
||||
@@ -292,9 +330,5 @@ func NewConfig(configFile string) *Config {
|
||||
cfg.Logging.Level = DefaultConfig.Logging.Level
|
||||
}
|
||||
|
||||
if cfg.Logging.Type == "" {
|
||||
cfg.Logging.Type = DefaultConfig.Logging.Type
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -4,44 +4,47 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewConfig_DefaultValues(t *testing.T) {
|
||||
cfg := NewConfig("nonexistent.yaml") // Should use defaults since file doesn't exist
|
||||
|
||||
assert.Equal(t, "0.0.0.0", cfg.Host)
|
||||
assert.Equal(t, 7777, cfg.Port)
|
||||
assert.Equal(t, 10, cfg.KeyLength)
|
||||
assert.Equal(t, "phonetic", cfg.KeyGenerator)
|
||||
assert.Equal(t, "file", cfg.Storage.Type)
|
||||
assert.Equal(t, "data", cfg.Storage.FilePath)
|
||||
assert.Equal(t, "info", cfg.Logging.Level)
|
||||
assert.Equal(t, "text", cfg.Logging.Type)
|
||||
require.Equal(t, "0.0.0.0", cfg.Host)
|
||||
require.Equal(t, 7777, cfg.Port)
|
||||
require.Equal(t, 10, cfg.KeyLength)
|
||||
require.Equal(t, "phonetic", cfg.KeyGenerator)
|
||||
require.Equal(t, "file", cfg.Storage.Type)
|
||||
require.Equal(t, "data", cfg.Storage.FilePath)
|
||||
require.Equal(t, "info", cfg.Logging.Level)
|
||||
}
|
||||
|
||||
func TestNewConfig_OverrideWithEnvVars(t *testing.T) {
|
||||
os.Setenv("HOST", "127.0.0.1")
|
||||
os.Setenv("PORT", "8080")
|
||||
os.Setenv("KEY_LENGTH", "15")
|
||||
os.Setenv("MAX_LENGTH", "5000000")
|
||||
os.Setenv("STORAGE_TYPE", "redis")
|
||||
os.Setenv("STORAGE_HOST", "localhost")
|
||||
os.Setenv("STORAGE_PORT", "6379")
|
||||
os.Setenv("LOGGING_LEVEL", "debug")
|
||||
t.Setenv("HOST", "127.0.0.1")
|
||||
t.Setenv("PORT", "8080")
|
||||
t.Setenv("KEY_LENGTH", "15")
|
||||
t.Setenv("MAX_LENGTH", "5000000")
|
||||
t.Setenv("STORAGE_TYPE", "redis")
|
||||
t.Setenv("STORAGE_HOST", "localhost")
|
||||
t.Setenv("STORAGE_PORT", "6379")
|
||||
t.Setenv("LOGGING_LEVEL", "debug")
|
||||
t.Setenv("RATE_LIMITING_ENABLE", "true")
|
||||
t.Setenv("RATE_LIMITING_LIMIT", "100")
|
||||
|
||||
defer os.Clearenv()
|
||||
|
||||
cfg := NewConfig("nonexistent.yaml") // Load with environment variables
|
||||
|
||||
assert.Equal(t, "127.0.0.1", cfg.Host)
|
||||
assert.Equal(t, 8080, cfg.Port)
|
||||
assert.Equal(t, 15, cfg.KeyLength)
|
||||
assert.Equal(t, 5000000, cfg.MaxLength)
|
||||
assert.Equal(t, "redis", cfg.Storage.Type)
|
||||
assert.Equal(t, "localhost", cfg.Storage.Host)
|
||||
assert.Equal(t, 6379, cfg.Storage.Port)
|
||||
assert.Equal(t, "debug", cfg.Logging.Level)
|
||||
require.Equal(t, "127.0.0.1", cfg.Host)
|
||||
require.Equal(t, 8080, cfg.Port)
|
||||
require.Equal(t, 15, cfg.KeyLength)
|
||||
require.Equal(t, 5000000, cfg.MaxLength)
|
||||
require.Equal(t, "redis", cfg.Storage.Type)
|
||||
require.Equal(t, "localhost", cfg.Storage.Host)
|
||||
require.Equal(t, 6379, cfg.Storage.Port)
|
||||
require.Equal(t, "debug", cfg.Logging.Level)
|
||||
require.Equal(t, true, cfg.RateLimiting.Enable)
|
||||
require.Equal(t, 100, cfg.RateLimiting.Limit)
|
||||
}
|
||||
|
||||
func TestNewConfig_LoadFromYAML(t *testing.T) {
|
||||
@@ -60,21 +63,20 @@ logging:
|
||||
|
||||
// Write to a temporary file
|
||||
tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
_, err = tmpFile.Write([]byte(yamlContent))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
tmpFile.Close()
|
||||
|
||||
cfg := NewConfig(tmpFile.Name()) // Load config from YAML file
|
||||
|
||||
assert.Equal(t, "192.168.1.1", cfg.Host)
|
||||
assert.Equal(t, 9090, cfg.Port)
|
||||
assert.Equal(t, 20, cfg.KeyLength)
|
||||
assert.Equal(t, "mongodb", cfg.Storage.Type)
|
||||
assert.Equal(t, "mongo.example.com", cfg.Storage.Host)
|
||||
assert.Equal(t, 27017, cfg.Storage.Port)
|
||||
assert.Equal(t, "warn", cfg.Logging.Level)
|
||||
assert.Equal(t, "json", cfg.Logging.Type)
|
||||
require.Equal(t, "192.168.1.1", cfg.Host)
|
||||
require.Equal(t, 9090, cfg.Port)
|
||||
require.Equal(t, 20, cfg.KeyLength)
|
||||
require.Equal(t, "mongodb", cfg.Storage.Type)
|
||||
require.Equal(t, "mongo.example.com", cfg.Storage.Host)
|
||||
require.Equal(t, 27017, cfg.Storage.Port)
|
||||
require.Equal(t, "warn", cfg.Logging.Level)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/armbian/ansi-hastebin/config"
|
||||
"github.com/armbian/ansi-hastebin/handler"
|
||||
@@ -13,8 +14,9 @@ import (
|
||||
"github.com/armbian/ansi-hastebin/internal/storage"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/httprate"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -46,6 +48,11 @@ func (s *Server) RegisterRoutes() {
|
||||
s.mux.Use(middleware.Logger)
|
||||
s.mux.Use(middleware.Recoverer)
|
||||
|
||||
// Rate limiter
|
||||
if s.config.RateLimiting.Enable {
|
||||
s.mux.Use(httprate.LimitByRealIP(s.config.RateLimiting.Limit, time.Duration(s.config.RateLimiting.Window)*time.Second))
|
||||
}
|
||||
|
||||
// Register promhttp middleware
|
||||
s.mux.Get("/metrics", promhttp.Handler().ServeHTTP)
|
||||
|
||||
@@ -86,21 +93,21 @@ func (s *Server) RegisterRoutes() {
|
||||
}
|
||||
|
||||
func (s *Server) Start() {
|
||||
logrus.Infof("Starting server on %s", s.server.Addr)
|
||||
log.Info().Str("host", s.config.Host).Int("port", s.config.Port).Msg("Starting server")
|
||||
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logrus.WithError(err).Fatal("Failed to start server")
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) {
|
||||
logrus.Info("Gracefully shutting down server")
|
||||
log.Info().Msg("Gracefully shutting down server")
|
||||
|
||||
if err := s.storage.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close storage")
|
||||
log.Error().Err(err).Msg("Failed to close storage")
|
||||
}
|
||||
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to shutdown server")
|
||||
log.Error().Err(err).Msg("Failed to shutdown server")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user