diff --git a/.golangci.yaml b/.golangci.yaml index fb43613..36a1f76 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -25,6 +25,7 @@ linters: case: rules: json: snake + yaml: snake disable: - cyclop # Redundant; cognitive complexity is already tracked by gocognit - depguard # Import restrictions are not currently required for this internal architecture diff --git a/README.md b/README.md index e0a78a1..fb1fbb3 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,7 @@ Usage: sql-tapd [flags] Flags: + -config path to config file (default: .sql-tap.yaml if exists) -driver database driver: postgres, mysql, tidb (required) -listen client listen address (required) -upstream upstream database address (required) @@ -137,12 +138,34 @@ Flags: -nplus1-threshold N+1 detection threshold (default: 5, 0 to disable) -nplus1-window N+1 detection time window (default: 1s) -nplus1-cooldown N+1 alert cooldown per query template (default: 10s) + -slow-threshold slow query threshold (default: 100ms, 0 to disable) -version show version and exit ``` Set `DATABASE_URL` (or the env var specified by `-dsn-env`) to enable EXPLAIN support. Without it, the proxy still captures queries but EXPLAIN is disabled. +### Config file + +Instead of passing flags on every invocation, you can create a `.sql-tap.yaml` in your project directory: + +```yaml +driver: postgres +listen: ":5433" +upstream: "localhost:5432" +grpc: ":9091" +http: ":8080" +dsn_env: DATABASE_URL +slow_threshold: 100ms +nplus1: + threshold: 5 + window: 1s + cooldown: 10s +``` + +sql-tapd automatically loads `.sql-tap.yaml` from the current directory. Use `-config` to specify a different path. +CLI flags override config file values. + ### Web UI Add `--http=:8080` to serve a browser-based viewer: diff --git a/cmd/sql-tapd/main.go b/cmd/sql-tapd/main.go index 39be883..beeb43d 100644 --- a/cmd/sql-tapd/main.go +++ b/cmd/sql-tapd/main.go @@ -16,6 +16,7 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/mickamy/sql-tap/broker" + "github.com/mickamy/sql-tap/config" "github.com/mickamy/sql-tap/detect" "github.com/mickamy/sql-tap/dsn" "github.com/mickamy/sql-tap/explain" @@ -32,11 +33,14 @@ var version = "dev" func main() { fs := flag.NewFlagSet("sql-tapd", flag.ExitOnError) fs.Usage = func() { - fmt.Fprintf(os.Stderr, "sql-tapd — SQL proxy daemon for sql-tap\n\nUsage:\n sql-tapd [flags]\n\nFlags:\n") + fmt.Fprintf(os.Stderr, + "sql-tapd — SQL proxy daemon for sql-tap\n\nUsage:\n sql-tapd [flags]\n\nFlags:\n") fs.PrintDefaults() - fmt.Fprintf(os.Stderr, "\nEnvironment:\n DATABASE_URL DSN for EXPLAIN queries (read by default via -dsn-env)\n") + fmt.Fprintf(os.Stderr, + "\nEnvironment:\n DATABASE_URL DSN for EXPLAIN queries (read by default via -dsn-env)\n") } + configPath := fs.String("config", "", "path to config file (default: .sql-tap.yaml)") driver := fs.String("driver", "", "database driver: postgres, mysql, tidb (required)") listen := fs.String("listen", "", "client listen address (required)") upstream := fs.String("upstream", "", "upstream database address (required)") @@ -56,24 +60,62 @@ func main() { return } - if *driver == "" || *listen == "" || *upstream == "" { + cfg, err := config.Load(*configPath) + if err != nil { + log.Fatal(err) + } + + // CLI flags override config file values. + set := flagsSet(fs) + if set["driver"] && *driver != "" { + cfg.Driver = *driver + } + if set["listen"] && *listen != "" { + cfg.Listen = *listen + } + if set["upstream"] && *upstream != "" { + cfg.Upstream = *upstream + } + if set["grpc"] && *grpcAddr != "" { + cfg.GRPC = *grpcAddr + } + if set["dsn-env"] && *dsnEnv != "" { + cfg.DSNEnv = *dsnEnv + } + if set["http"] { + cfg.HTTP = *httpAddr + } + if set["nplus1-threshold"] { + cfg.NPlus1.Threshold = *nplus1Threshold + } + if set["nplus1-window"] { + cfg.NPlus1.Window = *nplus1Window + } + if set["nplus1-cooldown"] { + cfg.NPlus1.Cooldown = *nplus1Cooldown + } + if set["slow-threshold"] { + cfg.SlowThreshold = *slowThreshold + } + + if cfg.Driver == "" || cfg.Listen == "" || cfg.Upstream == "" { fs.Usage() os.Exit(1) } - err := run( - *driver, *listen, *upstream, *grpcAddr, *dsnEnv, *httpAddr, - *nplus1Threshold, *nplus1Window, *nplus1Cooldown, *slowThreshold, - ) - if err != nil { + if err := run(cfg); err != nil { log.Fatal(err) } } -func run( - driver, listen, upstream, grpcAddr, dsnEnv, httpAddr string, - nplus1Threshold int, nplus1Window, nplus1Cooldown, slowThreshold time.Duration, -) error { +// flagsSet returns the set of flag names explicitly passed on the command line. +func flagsSet(fs *flag.FlagSet) map[string]bool { + m := make(map[string]bool) + fs.Visit(func(f *flag.Flag) { m[f.Name] = true }) + return m +} + +func run(cfg config.Config) error { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() @@ -82,13 +124,13 @@ func run( // EXPLAIN client (optional) var explainClient *explain.Client - if raw := os.Getenv(dsnEnv); raw != "" { + if raw := os.Getenv(cfg.DSNEnv); raw != "" { db, err := dsn.Open(raw) if err != nil { return fmt.Errorf("open db for explain: %w", err) } var explainDriver explain.Driver - switch driver { + switch cfg.Driver { case "mysql": explainDriver = explain.MySQL case "tidb": @@ -100,32 +142,32 @@ func run( defer func() { _ = explainClient.Close() }() log.Printf("EXPLAIN enabled") } else { - log.Printf("EXPLAIN disabled (%s not set)", dsnEnv) + log.Printf("EXPLAIN disabled (%s not set)", cfg.DSNEnv) } // gRPC server var lc net.ListenConfig - grpcLis, err := lc.Listen(ctx, "tcp", grpcAddr) + grpcLis, err := lc.Listen(ctx, "tcp", cfg.GRPC) if err != nil { - return fmt.Errorf("listen grpc %s: %w", grpcAddr, err) + return fmt.Errorf("listen grpc %s: %w", cfg.GRPC, err) } srv := server.New(b, explainClient) go func() { - log.Printf("gRPC server listening on %s", grpcAddr) + log.Printf("gRPC server listening on %s", cfg.GRPC) if err := srv.Serve(grpcLis); err != nil { log.Printf("grpc serve: %v", err) } }() // HTTP server (optional) - if httpAddr != "" { - httpLis, err := lc.Listen(ctx, "tcp", httpAddr) + if cfg.HTTP != "" { + httpLis, err := lc.Listen(ctx, "tcp", cfg.HTTP) if err != nil { - return fmt.Errorf("listen http %s: %w", httpAddr, err) + return fmt.Errorf("listen http %s: %w", cfg.HTTP, err) } webSrv := web.New(b, explainClient) go func() { - log.Printf("HTTP server listening on %s", httpAddr) + log.Printf("HTTP server listening on %s", cfg.HTTP) if err := webSrv.Serve(httpLis); err != nil { log.Printf("http serve: %v", err) } @@ -139,25 +181,25 @@ func run( // Proxy var p proxy.Proxy - switch driver { + switch cfg.Driver { case "postgres": - p = postgres.New(listen, upstream) + p = postgres.New(cfg.Listen, cfg.Upstream) case "mysql", "tidb": - p = mysql.New(listen, upstream) + p = mysql.New(cfg.Listen, cfg.Upstream) default: - return fmt.Errorf("unsupported driver: %s", driver) + return fmt.Errorf("unsupported driver: %s", cfg.Driver) } // N+1 detector (optional) var det *detect.Detector - if nplus1Threshold > 0 { - det = detect.New(nplus1Threshold, nplus1Window, nplus1Cooldown) + if cfg.NPlus1.Threshold > 0 { + det = detect.New(cfg.NPlus1.Threshold, cfg.NPlus1.Window, cfg.NPlus1.Cooldown) log.Printf("N+1 detection enabled (threshold=%d, window=%s, cooldown=%s)", - nplus1Threshold, nplus1Window, nplus1Cooldown) + cfg.NPlus1.Threshold, cfg.NPlus1.Window, cfg.NPlus1.Cooldown) } - if slowThreshold > 0 { - log.Printf("slow query detection enabled (threshold=%s)", slowThreshold) + if cfg.SlowThreshold > 0 { + log.Printf("slow query detection enabled (threshold=%s)", cfg.SlowThreshold) } go func() { @@ -170,17 +212,17 @@ func run( ev.NPlus1 = r.Matched if r.Alert != nil { log.Printf("N+1 detected: %q (%d times in %s)", - r.Alert.Query, r.Alert.Count, nplus1Window) + r.Alert.Query, r.Alert.Count, cfg.NPlus1.Window) } } - if slowThreshold > 0 && ev.Duration >= slowThreshold { + if cfg.SlowThreshold > 0 && ev.Duration >= cfg.SlowThreshold { ev.SlowQuery = true } b.Publish(ev) } }() - log.Printf("proxying %s -> %s (driver=%s)", listen, upstream, driver) + log.Printf("proxying %s -> %s (driver=%s)", cfg.Listen, cfg.Upstream, cfg.Driver) if err := p.ListenAndServe(ctx); err != nil { return fmt.Errorf("proxy: %w", err) } @@ -189,11 +231,11 @@ func run( return nil } -func isSelectQuery(op proxy.Op, query string) bool { +func isSelectQuery(op proxy.Op, q string) bool { switch op { case proxy.OpQuery, proxy.OpExec, proxy.OpExecute: - q := strings.TrimSpace(query) - return len(q) >= 6 && strings.EqualFold(q[:6], "SELECT") + trimmed := strings.TrimSpace(q) + return len(trimmed) >= 6 && strings.EqualFold(trimmed[:6], "SELECT") case proxy.OpPrepare, proxy.OpBind, proxy.OpBegin, proxy.OpCommit, proxy.OpRollback: return false } diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..76224f6 --- /dev/null +++ b/config/config.go @@ -0,0 +1,74 @@ +package config + +import ( + "bytes" + "errors" + "fmt" + "os" + "time" + + "gopkg.in/yaml.v3" +) + +// Config holds the sql-tapd configuration. +type Config struct { + Driver string `yaml:"driver"` + Listen string `yaml:"listen"` + Upstream string `yaml:"upstream"` + GRPC string `yaml:"grpc"` + HTTP string `yaml:"http"` + DSNEnv string `yaml:"dsn_env"` + SlowThreshold time.Duration `yaml:"slow_threshold"` + NPlus1 NPlus1Config `yaml:"nplus1"` +} + +// NPlus1Config holds N+1 detection settings. +type NPlus1Config struct { + Threshold int `yaml:"threshold"` + Window time.Duration `yaml:"window"` + Cooldown time.Duration `yaml:"cooldown"` +} + +// Default returns a Config with default values. +func Default() Config { + return Config{ + GRPC: ":9091", + DSNEnv: "DATABASE_URL", + SlowThreshold: 100 * time.Millisecond, + NPlus1: NPlus1Config{ + Threshold: 5, + Window: time.Second, + Cooldown: 10 * time.Second, + }, + } +} + +// defaultConfigFile is the config file name looked up in the current directory. +const defaultConfigFile = ".sql-tap.yaml" + +// Load reads the config file specified by path. If path is empty, it looks for +// the default config file in the current directory. If the default file does +// not exist, it returns the default config without error. +func Load(path string) (Config, error) { + cfg := Default() + + if path == "" { + path = defaultConfigFile + if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { + return cfg, nil + } + } + + data, err := os.ReadFile(path) //nolint:gosec // path is from user-provided flag or a fixed default + if err != nil { + return Config{}, fmt.Errorf("read config %s: %w", path, err) + } + + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + if err := dec.Decode(&cfg); err != nil { + return Config{}, fmt.Errorf("parse config %s: %w", path, err) + } + + return cfg, nil +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..8f47747 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,177 @@ +package config_test + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/mickamy/sql-tap/config" +) + +func TestDefault(t *testing.T) { + t.Parallel() + + cfg := config.Default() + + if cfg.GRPC != ":9091" { + t.Errorf("GRPC = %q, want %q", cfg.GRPC, ":9091") + } + if cfg.DSNEnv != "DATABASE_URL" { + t.Errorf("DSNEnv = %q, want %q", cfg.DSNEnv, "DATABASE_URL") + } + if cfg.SlowThreshold != 100*time.Millisecond { + t.Errorf("SlowThreshold = %s, want 100ms", cfg.SlowThreshold) + } + if cfg.NPlus1.Threshold != 5 { + t.Errorf("NPlus1.Threshold = %d, want 5", cfg.NPlus1.Threshold) + } + if cfg.NPlus1.Window != time.Second { + t.Errorf("NPlus1.Window = %s, want 1s", cfg.NPlus1.Window) + } + if cfg.NPlus1.Cooldown != 10*time.Second { + t.Errorf("NPlus1.Cooldown = %s, want 10s", cfg.NPlus1.Cooldown) + } +} + +func TestLoad_ExplicitPath(t *testing.T) { + t.Parallel() + + content := ` +driver: postgres +listen: ":5433" +upstream: "localhost:5432" +grpc: ":9999" +http: ":8080" +dsn_env: MY_DSN +slow_threshold: 200ms +nplus1: + threshold: 10 + window: 2s + cooldown: 30s +` + path := writeTemp(t, content) + + cfg, err := config.Load(path) + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Driver != "postgres" { + t.Errorf("Driver = %q, want %q", cfg.Driver, "postgres") + } + if cfg.Listen != ":5433" { + t.Errorf("Listen = %q, want %q", cfg.Listen, ":5433") + } + if cfg.Upstream != "localhost:5432" { + t.Errorf("Upstream = %q, want %q", cfg.Upstream, "localhost:5432") + } + if cfg.GRPC != ":9999" { + t.Errorf("GRPC = %q, want %q", cfg.GRPC, ":9999") + } + if cfg.HTTP != ":8080" { + t.Errorf("HTTP = %q, want %q", cfg.HTTP, ":8080") + } + if cfg.DSNEnv != "MY_DSN" { + t.Errorf("DSNEnv = %q, want %q", cfg.DSNEnv, "MY_DSN") + } + if cfg.SlowThreshold != 200*time.Millisecond { + t.Errorf("SlowThreshold = %s, want 200ms", cfg.SlowThreshold) + } + if cfg.NPlus1.Threshold != 10 { + t.Errorf("NPlus1.Threshold = %d, want 10", cfg.NPlus1.Threshold) + } + if cfg.NPlus1.Window != 2*time.Second { + t.Errorf("NPlus1.Window = %s, want 2s", cfg.NPlus1.Window) + } + if cfg.NPlus1.Cooldown != 30*time.Second { + t.Errorf("NPlus1.Cooldown = %s, want 30s", cfg.NPlus1.Cooldown) + } +} + +func TestLoad_PartialOverride(t *testing.T) { + t.Parallel() + + content := ` +driver: mysql +listen: ":3307" +upstream: "localhost:3306" +` + path := writeTemp(t, content) + + cfg, err := config.Load(path) + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Driver != "mysql" { + t.Errorf("Driver = %q, want %q", cfg.Driver, "mysql") + } + // Defaults should be preserved for unset fields. + if cfg.GRPC != ":9091" { + t.Errorf("GRPC = %q, want default %q", cfg.GRPC, ":9091") + } + if cfg.NPlus1.Threshold != 5 { + t.Errorf("NPlus1.Threshold = %d, want default 5", cfg.NPlus1.Threshold) + } +} + +func TestLoad_NoDefaultFile(t *testing.T) { //nolint:paralleltest // t.Chdir is incompatible with t.Parallel + t.Chdir(t.TempDir()) + + cfg, err := config.Load("") + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // Should return defaults. + want := config.Default() + if cfg.GRPC != want.GRPC { + t.Errorf("GRPC = %q, want %q", cfg.GRPC, want.GRPC) + } +} + +func TestLoad_FileNotFound(t *testing.T) { + t.Parallel() + + _, err := config.Load("/nonexistent/path.yaml") + if err == nil { + t.Fatal("expected error for missing explicit path") + } +} + +func TestLoad_InvalidYAML(t *testing.T) { + t.Parallel() + + path := writeTemp(t, "driver: [invalid yaml") + + _, err := config.Load(path) + if err == nil { + t.Fatal("expected error for invalid YAML") + } +} + +func TestLoad_UnknownField(t *testing.T) { + t.Parallel() + + content := ` +driver: postgres +grcp: ":9999" +` + path := writeTemp(t, content) + + _, err := config.Load(path) + if err == nil { + t.Fatal("expected error for unknown field 'grcp'") + } +} + +func writeTemp(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + return path +}