From 8ddbfd92e45d15bd7d9d69c042154564c88af86e Mon Sep 17 00:00:00 2001 From: Tetsuro Mikami Date: Wed, 11 Mar 2026 10:17:47 +0900 Subject: [PATCH 1/6] feat: add CI mode to sql-tap client for automated N+1 and slow query detection --- ci/ci.go | 192 ++++++++++++++++++++++++++++++++++++++++++++ ci/ci_test.go | 218 ++++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 34 +++++++- 3 files changed, 443 insertions(+), 1 deletion(-) create mode 100644 ci/ci.go create mode 100644 ci/ci_test.go diff --git a/ci/ci.go b/ci/ci.go new file mode 100644 index 0000000..e80e958 --- /dev/null +++ b/ci/ci.go @@ -0,0 +1,192 @@ +package ci + +import ( + "context" + "errors" + "fmt" + "io" + "sort" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + + tapv1 "github.com/mickamy/sql-tap/gen/tap/v1" +) + +// Result holds the CI run outcome. +type Result struct { + TotalQueries int + Problems []Problem +} + +// HasProblems reports whether any issues were detected. +func (r Result) HasProblems() bool { + return len(r.Problems) > 0 +} + +// ProblemKind categorizes a detected issue. +type ProblemKind string + +const ( + ProblemNPlus1 ProblemKind = "N+1" + ProblemSlowQuery ProblemKind = "SLOW" +) + +// Problem describes a single detected issue. +type Problem struct { + Kind ProblemKind + Query string + Count int + // AvgDuration is set only for ProblemSlowQuery. + AvgDuration time.Duration +} + +// Report formats the result as a human-readable string. +func (r Result) Report() string { + var b strings.Builder + b.WriteString("sql-tap CI Report\n") + b.WriteString("=================\n") + fmt.Fprintf(&b, "Captured: %d queries\n", r.TotalQueries) + + if !r.HasProblems() { + b.WriteString("\nNo problems found.\n") + return b.String() + } + + b.WriteString("\nProblems found:\n") + for _, p := range r.Problems { + switch p.Kind { + case ProblemNPlus1: + fmt.Fprintf(&b, " [N+1] %s (detected %d times)\n", p.Query, p.Count) + case ProblemSlowQuery: + avg := p.AvgDuration.Truncate(time.Millisecond) + fmt.Fprintf(&b, " [SLOW] %s (avg %s, %d occurrences)\n", p.Query, avg, p.Count) + } + } + + fmt.Fprintf(&b, "\nExit: 1 (%d problems found)\n", len(r.Problems)) + return b.String() +} + +// Run connects to the gRPC server at addr, collects query events until ctx is +// cancelled, and returns the aggregated result. +func Run(ctx context.Context, addr string) (Result, error) { + conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return Result{}, fmt.Errorf("dial %s: %w", addr, err) + } + defer func() { _ = conn.Close() }() + + client := tapv1.NewTapServiceClient(conn) + stream, err := client.Watch(ctx, &tapv1.WatchRequest{}) + if err != nil { + return Result{}, fmt.Errorf("watch %s: %w", addr, err) + } + + return collect(ctx, stream) +} + +func collect(ctx context.Context, stream tapv1.TapService_WatchClient) (Result, error) { + var events []*tapv1.QueryEvent + + for { + resp, err := stream.Recv() + if err != nil { + if isStreamDone(ctx, err) { + break + } + return Result{}, fmt.Errorf("recv: %w", err) + } + events = append(events, resp.GetEvent()) + } + + return Aggregate(events), nil +} + +func isStreamDone(ctx context.Context, err error) bool { + if errors.Is(err, io.EOF) { + return true + } + if ctx.Err() != nil { + return true + } + // gRPC wraps context errors in status; unwrap and check. + if s, ok := status.FromError(err); ok { + msg := s.Message() + return strings.Contains(msg, "context canceled") || + strings.Contains(msg, "context deadline exceeded") + } + return false +} + +// Aggregate computes the CI result from collected events. +func Aggregate(events []*tapv1.QueryEvent) Result { + result := Result{TotalQueries: len(events)} + + type stats struct { + nplus1Count int + slowCount int + totalDur time.Duration + } + + grouped := make(map[string]*stats) + + for _, e := range events { + if !e.GetNPlus_1() && !e.GetSlowQuery() { + continue + } + q := normalizedOrRaw(e) + s, ok := grouped[q] + if !ok { + s = &stats{} + grouped[q] = s + } + if e.GetNPlus_1() { + s.nplus1Count++ + } + if e.GetSlowQuery() { + s.slowCount++ + if d := e.GetDuration(); d != nil { + s.totalDur += d.AsDuration() + } + } + } + + for q, s := range grouped { + if s.nplus1Count > 0 { + result.Problems = append(result.Problems, Problem{ + Kind: ProblemNPlus1, + Query: q, + Count: s.nplus1Count, + }) + } + if s.slowCount > 0 { + avg := s.totalDur / time.Duration(s.slowCount) + result.Problems = append(result.Problems, Problem{ + Kind: ProblemSlowQuery, + Query: q, + Count: s.slowCount, + AvgDuration: avg, + }) + } + } + + sort.Slice(result.Problems, func(i, j int) bool { + if result.Problems[i].Kind != result.Problems[j].Kind { + return result.Problems[i].Kind < result.Problems[j].Kind + } + return result.Problems[i].Count > result.Problems[j].Count + }) + + return result +} + +func normalizedOrRaw(e *tapv1.QueryEvent) string { + if nq := e.GetNormalizedQuery(); nq != "" { + return nq + } + return e.GetQuery() +} diff --git a/ci/ci_test.go b/ci/ci_test.go new file mode 100644 index 0000000..25736d2 --- /dev/null +++ b/ci/ci_test.go @@ -0,0 +1,218 @@ +package ci_test + +import ( + "testing" + "time" + + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/mickamy/sql-tap/ci" + tapv1 "github.com/mickamy/sql-tap/gen/tap/v1" +) + +func TestResult_HasProblems(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + r ci.Result + want bool + }{ + { + name: "no problems", + r: ci.Result{TotalQueries: 10}, + want: false, + }, + { + name: "with problems", + r: ci.Result{ + TotalQueries: 10, + Problems: []ci.Problem{ + {Kind: ci.ProblemNPlus1, Query: "SELECT 1", Count: 5}, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.r.HasProblems(); got != tt.want { + t.Errorf("HasProblems() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestResult_Report_NoProblems(t *testing.T) { + t.Parallel() + + r := ci.Result{TotalQueries: 42} + report := r.Report() + + assertContains(t, report, "Captured: 42 queries") + assertContains(t, report, "No problems found.") +} + +func TestResult_Report_WithProblems(t *testing.T) { + t.Parallel() + + r := ci.Result{ + TotalQueries: 100, + Problems: []ci.Problem{ + {Kind: ci.ProblemNPlus1, Query: "SELECT * FROM comments WHERE post_id = ?", Count: 12}, + {Kind: ci.ProblemSlowQuery, Query: "SELECT * FROM users JOIN ...", Count: 3, AvgDuration: 523 * time.Millisecond}, + }, + } + report := r.Report() + + assertContains(t, report, "Captured: 100 queries") + assertContains(t, report, "[N+1]") + assertContains(t, report, "detected 12 times") + assertContains(t, report, "[SLOW]") + assertContains(t, report, "avg 523ms") + assertContains(t, report, "2 problems found") +} + +func TestAggregate(t *testing.T) { + t.Parallel() + + events := []*tapv1.QueryEvent{ + {Query: "SELECT 1", NormalizedQuery: "SELECT ?"}, + {Query: "SELECT * FROM users WHERE id = 1", NormalizedQuery: "SELECT * FROM users WHERE id = ?", NPlus_1: true}, + {Query: "SELECT * FROM users WHERE id = 2", NormalizedQuery: "SELECT * FROM users WHERE id = ?", NPlus_1: true}, + {Query: "SELECT * FROM users WHERE id = 3", NormalizedQuery: "SELECT * FROM users WHERE id = ?", NPlus_1: true}, + { + Query: "SELECT * FROM posts JOIN comments ON ...", + NormalizedQuery: "SELECT * FROM posts JOIN comments ON ...", + SlowQuery: true, + Duration: durationpb.New(200 * time.Millisecond), + }, + { + Query: "SELECT * FROM posts JOIN comments ON ...", + NormalizedQuery: "SELECT * FROM posts JOIN comments ON ...", + SlowQuery: true, + Duration: durationpb.New(400 * time.Millisecond), + }, + } + + result := ci.Aggregate(events) + + if result.TotalQueries != 6 { + t.Errorf("TotalQueries = %d, want 6", result.TotalQueries) + } + if len(result.Problems) != 2 { + t.Fatalf("len(Problems) = %d, want 2", len(result.Problems)) + } + + // N+1 problems are sorted first (N+1 < SLOW lexically). + nplus1 := result.Problems[0] + if nplus1.Kind != ci.ProblemNPlus1 { + t.Errorf("Problems[0].Kind = %s, want N+1", nplus1.Kind) + } + if nplus1.Count != 3 { + t.Errorf("Problems[0].Count = %d, want 3", nplus1.Count) + } + + slow := result.Problems[1] + if slow.Kind != ci.ProblemSlowQuery { + t.Errorf("Problems[1].Kind = %s, want SLOW", slow.Kind) + } + if slow.Count != 2 { + t.Errorf("Problems[1].Count = %d, want 2", slow.Count) + } + if slow.AvgDuration != 300*time.Millisecond { + t.Errorf("Problems[1].AvgDuration = %s, want 300ms", slow.AvgDuration) + } +} + +func TestAggregate_NoProblemEvents(t *testing.T) { + t.Parallel() + + events := []*tapv1.QueryEvent{ + {Query: "SELECT 1"}, + {Query: "INSERT INTO users VALUES (1)"}, + } + + result := ci.Aggregate(events) + + if result.TotalQueries != 2 { + t.Errorf("TotalQueries = %d, want 2", result.TotalQueries) + } + if result.HasProblems() { + t.Error("expected no problems") + } +} + +func TestAggregate_Empty(t *testing.T) { + t.Parallel() + + result := ci.Aggregate(nil) + + if result.TotalQueries != 0 { + t.Errorf("TotalQueries = %d, want 0", result.TotalQueries) + } + if result.HasProblems() { + t.Error("expected no problems") + } +} + +func TestAggregate_BothNPlus1AndSlow(t *testing.T) { + t.Parallel() + + events := []*tapv1.QueryEvent{ + { + Query: "SELECT * FROM users WHERE id = ?", + NormalizedQuery: "SELECT * FROM users WHERE id = ?", + NPlus_1: true, + SlowQuery: true, + Duration: durationpb.New(150 * time.Millisecond), + }, + } + + result := ci.Aggregate(events) + + if len(result.Problems) != 2 { + t.Fatalf("len(Problems) = %d, want 2 (one N+1, one SLOW)", len(result.Problems)) + } +} + +func TestAggregate_UsesRawQueryWhenNormalizedEmpty(t *testing.T) { + t.Parallel() + + const rawQuery = "SELECT id, name FROM users" + events := []*tapv1.QueryEvent{ + {Query: rawQuery, NPlus_1: true}, + {Query: rawQuery, NPlus_1: true}, + } + + result := ci.Aggregate(events) + + if len(result.Problems) != 1 { + t.Fatalf("len(Problems) = %d, want 1", len(result.Problems)) + } + if result.Problems[0].Query != rawQuery { + t.Errorf("Query = %q, want %q", result.Problems[0].Query, rawQuery) + } +} + +func assertContains(t *testing.T, s, substr string) { + t.Helper() + if !contains(s, substr) { + t.Errorf("expected report to contain %q, got:\n%s", substr, s) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchString(s, substr) +} + +func searchString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/main.go b/main.go index 9a10798..8727896 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,16 @@ package main import ( + "context" "flag" "fmt" "os" + "os/signal" + "syscall" tea "github.com/charmbracelet/bubbletea" + "github.com/mickamy/sql-tap/ci" "github.com/mickamy/sql-tap/tui" ) @@ -20,6 +24,7 @@ func main() { } showVersion := fs.Bool("version", false, "show version and exit") + ciMode := fs.Bool("ci", false, "run in CI mode: collect events until SIGTERM/SIGINT, then report and exit") _ = fs.Parse(os.Args[1:]) @@ -33,7 +38,12 @@ func main() { os.Exit(1) } - monitor(fs.Arg(0)) + addr := fs.Arg(0) + if *ciMode { + runCI(addr) + } else { + monitor(addr) + } } func monitor(addr string) { @@ -44,3 +54,25 @@ func monitor(addr string) { os.Exit(1) } } + +func runCI(addr string) { + os.Exit(runCIExitCode(addr)) +} + +func runCIExitCode(addr string) int { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + result, err := ci.Run(ctx, addr) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + return 1 + } + + fmt.Fprint(os.Stderr, result.Report()) + + if result.HasProblems() { + return 1 + } + return 0 +} From 6096baaaa0e55ee826bd02787b093925d7eee6b4 Mon Sep 17 00:00:00 2001 From: Tetsuro Mikami Date: Wed, 11 Mar 2026 10:23:22 +0900 Subject: [PATCH 2/6] lint: remove unused nolint directives in proxy tests --- proxy/mysql/proxy_test.go | 2 +- proxy/postgres/proxy_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/proxy/mysql/proxy_test.go b/proxy/mysql/proxy_test.go index fdee56d..89af258 100644 --- a/proxy/mysql/proxy_test.go +++ b/proxy/mysql/proxy_test.go @@ -64,7 +64,7 @@ func startProxy(t *testing.T, upstream string) (*mproxy.Proxy, string) { _ = lis.Close() p := mproxy.New(addr, upstream) - ctx, cancel := context.WithCancel(t.Context()) //nolint:gosec // cancel is deferred below via t.Cleanup + ctx, cancel := context.WithCancel(t.Context()) go func() { if err := p.ListenAndServe(ctx); err != nil { diff --git a/proxy/postgres/proxy_test.go b/proxy/postgres/proxy_test.go index 33af736..97292aa 100644 --- a/proxy/postgres/proxy_test.go +++ b/proxy/postgres/proxy_test.go @@ -70,7 +70,7 @@ func startProxy(t *testing.T, upstream string) (*pproxy.Proxy, string) { _ = lis.Close() p := pproxy.New(addr, upstream) - ctx, cancel := context.WithCancel(t.Context()) //nolint:gosec // cancel is deferred below via t.Cleanup + ctx, cancel := context.WithCancel(t.Context()) go func() { if err := p.ListenAndServe(ctx); err != nil { From af8e09b8ceecc43bf0cb26f51285ec4adc80d7ab Mon Sep 17 00:00:00 2001 From: Tetsuro Mikami Date: Wed, 11 Mar 2026 10:25:43 +0900 Subject: [PATCH 3/6] fix: use gRPC status codes instead of string matching in isStreamDone --- ci/ci.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ci/ci.go b/ci/ci.go index e80e958..9c78d7e 100644 --- a/ci/ci.go +++ b/ci/ci.go @@ -10,6 +10,7 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" @@ -113,13 +114,8 @@ func isStreamDone(ctx context.Context, err error) bool { if ctx.Err() != nil { return true } - // gRPC wraps context errors in status; unwrap and check. - if s, ok := status.FromError(err); ok { - msg := s.Message() - return strings.Contains(msg, "context canceled") || - strings.Contains(msg, "context deadline exceeded") - } - return false + code := status.Code(err) + return code == codes.Canceled || code == codes.DeadlineExceeded } // Aggregate computes the CI result from collected events. From 30c2562041dd70580e11b02ab412c47004d95315 Mon Sep 17 00:00:00 2001 From: Tetsuro Mikami Date: Wed, 11 Mar 2026 10:26:45 +0900 Subject: [PATCH 4/6] refactor: use streaming aggregation instead of storing all events in memory --- ci/ci.go | 93 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/ci/ci.go b/ci/ci.go index 9c78d7e..b197e06 100644 --- a/ci/ci.go +++ b/ci/ci.go @@ -91,7 +91,7 @@ func Run(ctx context.Context, addr string) (Result, error) { } func collect(ctx context.Context, stream tapv1.TapService_WatchClient) (Result, error) { - var events []*tapv1.QueryEvent + a := newAggregator() for { resp, err := stream.Recv() @@ -101,10 +101,10 @@ func collect(ctx context.Context, stream tapv1.TapService_WatchClient) (Result, } return Result{}, fmt.Errorf("recv: %w", err) } - events = append(events, resp.GetEvent()) + a.add(resp.GetEvent()) } - return Aggregate(events), nil + return a.result(), nil } func isStreamDone(ctx context.Context, err error) bool { @@ -118,42 +118,49 @@ func isStreamDone(ctx context.Context, err error) bool { return code == codes.Canceled || code == codes.DeadlineExceeded } -// Aggregate computes the CI result from collected events. -func Aggregate(events []*tapv1.QueryEvent) Result { - result := Result{TotalQueries: len(events)} +type queryStats struct { + nplus1Count int + slowCount int + totalDur time.Duration +} - type stats struct { - nplus1Count int - slowCount int - totalDur time.Duration - } +type aggregator struct { + total int + grouped map[string]*queryStats +} - grouped := make(map[string]*stats) +func newAggregator() *aggregator { + return &aggregator{grouped: make(map[string]*queryStats)} +} - for _, e := range events { - if !e.GetNPlus_1() && !e.GetSlowQuery() { - continue - } - q := normalizedOrRaw(e) - s, ok := grouped[q] - if !ok { - s = &stats{} - grouped[q] = s - } - if e.GetNPlus_1() { - s.nplus1Count++ - } - if e.GetSlowQuery() { - s.slowCount++ - if d := e.GetDuration(); d != nil { - s.totalDur += d.AsDuration() - } +func (a *aggregator) add(e *tapv1.QueryEvent) { + a.total++ + if !e.GetNPlus_1() && !e.GetSlowQuery() { + return + } + q := normalizedOrRaw(e) + s, ok := a.grouped[q] + if !ok { + s = &queryStats{} + a.grouped[q] = s + } + if e.GetNPlus_1() { + s.nplus1Count++ + } + if e.GetSlowQuery() { + s.slowCount++ + if d := e.GetDuration(); d != nil { + s.totalDur += d.AsDuration() } } +} + +func (a *aggregator) result() Result { + r := Result{TotalQueries: a.total} - for q, s := range grouped { + for q, s := range a.grouped { if s.nplus1Count > 0 { - result.Problems = append(result.Problems, Problem{ + r.Problems = append(r.Problems, Problem{ Kind: ProblemNPlus1, Query: q, Count: s.nplus1Count, @@ -161,7 +168,7 @@ func Aggregate(events []*tapv1.QueryEvent) Result { } if s.slowCount > 0 { avg := s.totalDur / time.Duration(s.slowCount) - result.Problems = append(result.Problems, Problem{ + r.Problems = append(r.Problems, Problem{ Kind: ProblemSlowQuery, Query: q, Count: s.slowCount, @@ -170,14 +177,24 @@ func Aggregate(events []*tapv1.QueryEvent) Result { } } - sort.Slice(result.Problems, func(i, j int) bool { - if result.Problems[i].Kind != result.Problems[j].Kind { - return result.Problems[i].Kind < result.Problems[j].Kind + sort.Slice(r.Problems, func(i, j int) bool { + if r.Problems[i].Kind != r.Problems[j].Kind { + return r.Problems[i].Kind < r.Problems[j].Kind } - return result.Problems[i].Count > result.Problems[j].Count + return r.Problems[i].Count > r.Problems[j].Count }) - return result + return r +} + +// Aggregate computes the CI result from the given events. +// Intended for testing; the streaming path uses aggregator directly. +func Aggregate(events []*tapv1.QueryEvent) Result { + a := newAggregator() + for _, e := range events { + a.add(e) + } + return a.result() } func normalizedOrRaw(e *tapv1.QueryEvent) string { From fd83936af7e914afd457f179d763d743708a7c65 Mon Sep 17 00:00:00 2001 From: Tetsuro Mikami Date: Wed, 11 Mar 2026 10:34:08 +0900 Subject: [PATCH 5/6] fix: address review comments on CI mode - Update doc comment and help text to mention stream-end termination - Print CI report to stdout instead of stderr - Add default case in Report for unknown ProblemKind --- ci/ci.go | 4 +++- main.go | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ci/ci.go b/ci/ci.go index b197e06..e06b311 100644 --- a/ci/ci.go +++ b/ci/ci.go @@ -65,6 +65,8 @@ func (r Result) Report() string { case ProblemSlowQuery: avg := p.AvgDuration.Truncate(time.Millisecond) fmt.Fprintf(&b, " [SLOW] %s (avg %s, %d occurrences)\n", p.Query, avg, p.Count) + default: + fmt.Fprintf(&b, " [%s] %s (%d occurrences)\n", string(p.Kind), p.Query, p.Count) } } @@ -73,7 +75,7 @@ func (r Result) Report() string { } // Run connects to the gRPC server at addr, collects query events until ctx is -// cancelled, and returns the aggregated result. +// cancelled or the server closes the stream, and returns the aggregated result. func Run(ctx context.Context, addr string) (Result, error) { conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { diff --git a/main.go b/main.go index 8727896..ebe138b 100644 --- a/main.go +++ b/main.go @@ -24,7 +24,8 @@ func main() { } showVersion := fs.Bool("version", false, "show version and exit") - ciMode := fs.Bool("ci", false, "run in CI mode: collect events until SIGTERM/SIGINT, then report and exit") + ciMode := fs.Bool("ci", false, + "run in CI mode: collect events until SIGTERM/SIGINT or stream ends, then report and exit") _ = fs.Parse(os.Args[1:]) @@ -69,7 +70,7 @@ func runCIExitCode(addr string) int { return 1 } - fmt.Fprint(os.Stderr, result.Report()) + fmt.Fprint(os.Stdout, result.Report()) if result.HasProblems() { return 1 From 4384ae3d96c654d5e6cdb4fa1ab859670e4e7327 Mon Sep 17 00:00:00 2001 From: Tetsuro Mikami Date: Wed, 11 Mar 2026 10:37:00 +0900 Subject: [PATCH 6/6] test: add integration tests for ci.Run with in-process gRPC server --- ci/ci_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/ci/ci_test.go b/ci/ci_test.go index 25736d2..0b0dd54 100644 --- a/ci/ci_test.go +++ b/ci/ci_test.go @@ -1,13 +1,18 @@ package ci_test import ( + "context" + "net" "testing" "time" "google.golang.org/protobuf/types/known/durationpb" + "github.com/mickamy/sql-tap/broker" "github.com/mickamy/sql-tap/ci" tapv1 "github.com/mickamy/sql-tap/gen/tap/v1" + "github.com/mickamy/sql-tap/proxy" + "github.com/mickamy/sql-tap/server" ) func TestResult_HasProblems(t *testing.T) { @@ -197,6 +202,117 @@ func TestAggregate_UsesRawQueryWhenNormalizedEmpty(t *testing.T) { } } +func startServer(t *testing.T, b *broker.Broker) string { + t.Helper() + + var lc net.ListenConfig + lis, err := lc.Listen(t.Context(), "tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + srv := server.New(b, nil) + go func() { _ = srv.Serve(lis) }() + t.Cleanup(srv.Stop) + + return lis.Addr().String() +} + +func TestRun_AggregatesEventsOnContextCancel(t *testing.T) { + t.Parallel() + + b := broker.New(8) + addr := startServer(t, b) + + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan ci.Result, 1) + errc := make(chan error, 1) + go func() { + result, err := ci.Run(ctx, addr) + if err != nil { + errc <- err + return + } + done <- result + }() + + // Wait for subscription to register. + time.Sleep(50 * time.Millisecond) + + b.Publish(proxy.Event{ID: "1", Op: proxy.OpQuery, Query: "SELECT 1"}) + b.Publish(proxy.Event{ + ID: "2", Op: proxy.OpQuery, Query: "SELECT id FROM users WHERE id = 1", + NPlus1: true, + NormalizedQuery: "SELECT id FROM users WHERE id = ?", + }) + b.Publish(proxy.Event{ + ID: "3", Op: proxy.OpQuery, Query: "SELECT id FROM users WHERE id = 2", + NPlus1: true, + NormalizedQuery: "SELECT id FROM users WHERE id = ?", + }) + + // Give events time to arrive, then cancel. + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case result := <-done: + if result.TotalQueries != 3 { + t.Errorf("TotalQueries = %d, want 3", result.TotalQueries) + } + if !result.HasProblems() { + t.Error("expected problems") + } + case err := <-errc: + t.Fatalf("Run returned error: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for Run to return") + } +} + +func TestRun_NoProblemEvents(t *testing.T) { + t.Parallel() + + b := broker.New(8) + addr := startServer(t, b) + + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan ci.Result, 1) + errc := make(chan error, 1) + go func() { + result, err := ci.Run(ctx, addr) + if err != nil { + errc <- err + return + } + done <- result + }() + + time.Sleep(50 * time.Millisecond) + + b.Publish(proxy.Event{ID: "1", Op: proxy.OpQuery, Query: "SELECT 1"}) + b.Publish(proxy.Event{ID: "2", Op: proxy.OpExec, Query: "INSERT INTO t VALUES (1)"}) + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case result := <-done: + if result.TotalQueries != 2 { + t.Errorf("TotalQueries = %d, want 2", result.TotalQueries) + } + if result.HasProblems() { + t.Error("expected no problems") + } + case err := <-errc: + t.Fatalf("Run returned error: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for Run to return") + } +} + func assertContains(t *testing.T, s, substr string) { t.Helper() if !contains(s, substr) {