Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ linters:
case:
rules:
json: snake
yaml: snake
disable:
- cyclop # Redundant; cognitive complexity is already tracked by gocognit
- depguard # Import restrictions are not currently required for this internal architecture
Expand Down
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ Usage:
sql-tapd [flags]

Flags:
-config path to config file (default: .sql-tap.yaml if exists)
-driver database driver: postgres, mysql, tidb (required)
-listen client listen address (required)
-upstream upstream database address (required)
Expand All @@ -137,12 +138,34 @@ Flags:
-nplus1-threshold N+1 detection threshold (default: 5, 0 to disable)
-nplus1-window N+1 detection time window (default: 1s)
-nplus1-cooldown N+1 alert cooldown per query template (default: 10s)
-slow-threshold slow query threshold (default: 100ms, 0 to disable)
-version show version and exit
```

Set `DATABASE_URL` (or the env var specified by `-dsn-env`) to enable EXPLAIN support. Without it, the proxy still
captures queries but EXPLAIN is disabled.

### Config file

Instead of passing flags on every invocation, you can create a `.sql-tap.yaml` in your project directory:

```yaml
driver: postgres
listen: ":5433"
upstream: "localhost:5432"
grpc: ":9091"
http: ":8080"
dsn_env: DATABASE_URL
slow_threshold: 100ms
nplus1:
threshold: 5
window: 1s
cooldown: 10s
```

sql-tapd automatically loads `.sql-tap.yaml` from the current directory. Use `-config` to specify a different path.
CLI flags override config file values.

### Web UI

Add `--http=:8080` to serve a browser-based viewer:
Expand Down
116 changes: 79 additions & 37 deletions cmd/sql-tapd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
_ "github.com/jackc/pgx/v5/stdlib"

"github.com/mickamy/sql-tap/broker"
"github.com/mickamy/sql-tap/config"
"github.com/mickamy/sql-tap/detect"
"github.com/mickamy/sql-tap/dsn"
"github.com/mickamy/sql-tap/explain"
Expand All @@ -32,11 +33,14 @@ var version = "dev"
func main() {
fs := flag.NewFlagSet("sql-tapd", flag.ExitOnError)
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "sql-tapd — SQL proxy daemon for sql-tap\n\nUsage:\n sql-tapd [flags]\n\nFlags:\n")
fmt.Fprintf(os.Stderr,
"sql-tapd — SQL proxy daemon for sql-tap\n\nUsage:\n sql-tapd [flags]\n\nFlags:\n")
fs.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nEnvironment:\n DATABASE_URL DSN for EXPLAIN queries (read by default via -dsn-env)\n")
fmt.Fprintf(os.Stderr,
"\nEnvironment:\n DATABASE_URL DSN for EXPLAIN queries (read by default via -dsn-env)\n")
}

configPath := fs.String("config", "", "path to config file (default: .sql-tap.yaml)")
driver := fs.String("driver", "", "database driver: postgres, mysql, tidb (required)")
listen := fs.String("listen", "", "client listen address (required)")
upstream := fs.String("upstream", "", "upstream database address (required)")
Comment on lines 44 to 46
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flag descriptions for -driver/-listen/-upstream still say "(required)", but this PR also allows providing those values via the config file. Consider updating the help text to reflect that they are only required when not set in the config (e.g., "required unless provided by -config / .sql-tap.yaml").

Suggested change
driver := fs.String("driver", "", "database driver: postgres, mysql, tidb (required)")
listen := fs.String("listen", "", "client listen address (required)")
upstream := fs.String("upstream", "", "upstream database address (required)")
driver := fs.String("driver", "", "database driver: postgres, mysql, tidb (required unless provided by -config / .sql-tap.yaml)")
listen := fs.String("listen", "", "client listen address (required unless provided by -config / .sql-tap.yaml)")
upstream := fs.String("upstream", "", "upstream database address (required unless provided by -config / .sql-tap.yaml)")

Copilot uses AI. Check for mistakes.
Expand All @@ -56,24 +60,62 @@ func main() {
return
}

if *driver == "" || *listen == "" || *upstream == "" {
cfg, err := config.Load(*configPath)
if err != nil {
log.Fatal(err)
}

// CLI flags override config file values.
set := flagsSet(fs)
if set["driver"] && *driver != "" {
cfg.Driver = *driver
}
if set["listen"] && *listen != "" {
cfg.Listen = *listen
}
if set["upstream"] && *upstream != "" {
cfg.Upstream = *upstream
}
if set["grpc"] && *grpcAddr != "" {
cfg.GRPC = *grpcAddr
}
if set["dsn-env"] && *dsnEnv != "" {
cfg.DSNEnv = *dsnEnv
}
if set["http"] {
cfg.HTTP = *httpAddr
}
if set["nplus1-threshold"] {
cfg.NPlus1.Threshold = *nplus1Threshold
}
if set["nplus1-window"] {
cfg.NPlus1.Window = *nplus1Window
}
if set["nplus1-cooldown"] {
cfg.NPlus1.Cooldown = *nplus1Cooldown
}
if set["slow-threshold"] {
cfg.SlowThreshold = *slowThreshold
}

if cfg.Driver == "" || cfg.Listen == "" || cfg.Upstream == "" {
fs.Usage()
os.Exit(1)
}

err := run(
*driver, *listen, *upstream, *grpcAddr, *dsnEnv, *httpAddr,
*nplus1Threshold, *nplus1Window, *nplus1Cooldown, *slowThreshold,
)
if err != nil {
if err := run(cfg); err != nil {
log.Fatal(err)
}
}

func run(
driver, listen, upstream, grpcAddr, dsnEnv, httpAddr string,
nplus1Threshold int, nplus1Window, nplus1Cooldown, slowThreshold time.Duration,
) error {
// flagsSet returns the set of flag names explicitly passed on the command line.
func flagsSet(fs *flag.FlagSet) map[string]bool {
m := make(map[string]bool)
fs.Visit(func(f *flag.Flag) { m[f.Name] = true })
return m
}

func run(cfg config.Config) error {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()

Expand All @@ -82,13 +124,13 @@ func run(

// EXPLAIN client (optional)
var explainClient *explain.Client
if raw := os.Getenv(dsnEnv); raw != "" {
if raw := os.Getenv(cfg.DSNEnv); raw != "" {
db, err := dsn.Open(raw)
if err != nil {
return fmt.Errorf("open db for explain: %w", err)
}
var explainDriver explain.Driver
switch driver {
switch cfg.Driver {
case "mysql":
explainDriver = explain.MySQL
case "tidb":
Expand All @@ -100,32 +142,32 @@ func run(
defer func() { _ = explainClient.Close() }()
log.Printf("EXPLAIN enabled")
} else {
log.Printf("EXPLAIN disabled (%s not set)", dsnEnv)
log.Printf("EXPLAIN disabled (%s not set)", cfg.DSNEnv)
}

// gRPC server
var lc net.ListenConfig
grpcLis, err := lc.Listen(ctx, "tcp", grpcAddr)
grpcLis, err := lc.Listen(ctx, "tcp", cfg.GRPC)
if err != nil {
return fmt.Errorf("listen grpc %s: %w", grpcAddr, err)
return fmt.Errorf("listen grpc %s: %w", cfg.GRPC, err)
}
srv := server.New(b, explainClient)
go func() {
log.Printf("gRPC server listening on %s", grpcAddr)
log.Printf("gRPC server listening on %s", cfg.GRPC)
if err := srv.Serve(grpcLis); err != nil {
log.Printf("grpc serve: %v", err)
}
}()

// HTTP server (optional)
if httpAddr != "" {
httpLis, err := lc.Listen(ctx, "tcp", httpAddr)
if cfg.HTTP != "" {
httpLis, err := lc.Listen(ctx, "tcp", cfg.HTTP)
if err != nil {
return fmt.Errorf("listen http %s: %w", httpAddr, err)
return fmt.Errorf("listen http %s: %w", cfg.HTTP, err)
}
webSrv := web.New(b, explainClient)
go func() {
log.Printf("HTTP server listening on %s", httpAddr)
log.Printf("HTTP server listening on %s", cfg.HTTP)
if err := webSrv.Serve(httpLis); err != nil {
log.Printf("http serve: %v", err)
}
Expand All @@ -139,25 +181,25 @@ func run(

// Proxy
var p proxy.Proxy
switch driver {
switch cfg.Driver {
case "postgres":
p = postgres.New(listen, upstream)
p = postgres.New(cfg.Listen, cfg.Upstream)
case "mysql", "tidb":
p = mysql.New(listen, upstream)
p = mysql.New(cfg.Listen, cfg.Upstream)
default:
return fmt.Errorf("unsupported driver: %s", driver)
return fmt.Errorf("unsupported driver: %s", cfg.Driver)
}

// N+1 detector (optional)
var det *detect.Detector
if nplus1Threshold > 0 {
det = detect.New(nplus1Threshold, nplus1Window, nplus1Cooldown)
if cfg.NPlus1.Threshold > 0 {
det = detect.New(cfg.NPlus1.Threshold, cfg.NPlus1.Window, cfg.NPlus1.Cooldown)
log.Printf("N+1 detection enabled (threshold=%d, window=%s, cooldown=%s)",
nplus1Threshold, nplus1Window, nplus1Cooldown)
cfg.NPlus1.Threshold, cfg.NPlus1.Window, cfg.NPlus1.Cooldown)
}

if slowThreshold > 0 {
log.Printf("slow query detection enabled (threshold=%s)", slowThreshold)
if cfg.SlowThreshold > 0 {
log.Printf("slow query detection enabled (threshold=%s)", cfg.SlowThreshold)
}

go func() {
Expand All @@ -170,17 +212,17 @@ func run(
ev.NPlus1 = r.Matched
if r.Alert != nil {
log.Printf("N+1 detected: %q (%d times in %s)",
r.Alert.Query, r.Alert.Count, nplus1Window)
r.Alert.Query, r.Alert.Count, cfg.NPlus1.Window)
}
}
if slowThreshold > 0 && ev.Duration >= slowThreshold {
if cfg.SlowThreshold > 0 && ev.Duration >= cfg.SlowThreshold {
ev.SlowQuery = true
}
b.Publish(ev)
}
}()

log.Printf("proxying %s -> %s (driver=%s)", listen, upstream, driver)
log.Printf("proxying %s -> %s (driver=%s)", cfg.Listen, cfg.Upstream, cfg.Driver)
if err := p.ListenAndServe(ctx); err != nil {
return fmt.Errorf("proxy: %w", err)
}
Expand All @@ -189,11 +231,11 @@ func run(
return nil
}

func isSelectQuery(op proxy.Op, query string) bool {
func isSelectQuery(op proxy.Op, q string) bool {
switch op {
case proxy.OpQuery, proxy.OpExec, proxy.OpExecute:
q := strings.TrimSpace(query)
return len(q) >= 6 && strings.EqualFold(q[:6], "SELECT")
trimmed := strings.TrimSpace(q)
return len(trimmed) >= 6 && strings.EqualFold(trimmed[:6], "SELECT")
case proxy.OpPrepare, proxy.OpBind, proxy.OpBegin, proxy.OpCommit, proxy.OpRollback:
return false
}
Expand Down
74 changes: 74 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package config

import (
"bytes"
"errors"
"fmt"
"os"
"time"

"gopkg.in/yaml.v3"
)

// Config holds the sql-tapd configuration.
type Config struct {
Driver string `yaml:"driver"`
Listen string `yaml:"listen"`
Upstream string `yaml:"upstream"`
GRPC string `yaml:"grpc"`
HTTP string `yaml:"http"`
DSNEnv string `yaml:"dsn_env"`
SlowThreshold time.Duration `yaml:"slow_threshold"`
NPlus1 NPlus1Config `yaml:"nplus1"`
}

// NPlus1Config holds N+1 detection settings.
type NPlus1Config struct {
Threshold int `yaml:"threshold"`
Window time.Duration `yaml:"window"`
Cooldown time.Duration `yaml:"cooldown"`
}

// Default returns a Config with default values.
func Default() Config {
return Config{
GRPC: ":9091",
DSNEnv: "DATABASE_URL",
SlowThreshold: 100 * time.Millisecond,
NPlus1: NPlus1Config{
Threshold: 5,
Window: time.Second,
Cooldown: 10 * time.Second,
},
}
}

// defaultConfigFile is the config file name looked up in the current directory.
const defaultConfigFile = ".sql-tap.yaml"

// Load reads the config file specified by path. If path is empty, it looks for
// the default config file in the current directory. If the default file does
// not exist, it returns the default config without error.
func Load(path string) (Config, error) {
cfg := Default()

if path == "" {
path = defaultConfigFile
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
return cfg, nil
}
}

data, err := os.ReadFile(path) //nolint:gosec // path is from user-provided flag or a fixed default
if err != nil {
return Config{}, fmt.Errorf("read config %s: %w", path, err)
}

dec := yaml.NewDecoder(bytes.NewReader(data))
dec.KnownFields(true)
if err := dec.Decode(&cfg); err != nil {
return Config{}, fmt.Errorf("parse config %s: %w", path, err)
}

return cfg, nil
}
Loading
Loading