diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a0aa33a..2a36764 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ on: - "**" permissions: - contents: write + contents: read jobs: lint: @@ -21,22 +21,19 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: "1.25" + go-version: "1.26" - name: golangci-lint uses: golangci/golangci-lint-action@v9 test: runs-on: ubuntu-latest - strategy: - matrix: - go-version: [ "1.25", "1.26" ] steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} + go-version: "1.26" - name: Run tests run: make test diff --git a/README.md b/README.md index a5d54cf..1580aec 100644 --- a/README.md +++ b/README.md @@ -65,26 +65,25 @@ Then reference it in `devenv.nix`: ### Docker -**PostgreSQL** - ```dockerfile -FROM postgres:18-alpine +FROM alpine:3 ARG SQL_TAP_VERSION=0.0.1 ARG TARGETARCH ADD https://github.com/mickamy/sql-tap/releases/download/v${SQL_TAP_VERSION}/sql-tap_${SQL_TAP_VERSION}_linux_${TARGETARCH}.tar.gz /tmp/sql-tap.tar.gz RUN tar -xzf /tmp/sql-tap.tar.gz -C /usr/local/bin sql-tapd && rm /tmp/sql-tap.tar.gz -ENTRYPOINT ["sql-tapd", "--driver=postgres", "--listen=:5433", "--upstream=localhost:5432", "--grpc=:9091"] +ENTRYPOINT ["sql-tapd"] ``` -**MySQL** +Run as a sidecar alongside your database: -```dockerfile -FROM mysql:8 -ARG SQL_TAP_VERSION=0.0.1 -ARG TARGETARCH -ADD https://github.com/mickamy/sql-tap/releases/download/v${SQL_TAP_VERSION}/sql-tap_${SQL_TAP_VERSION}_linux_${TARGETARCH}.tar.gz /tmp/sql-tap.tar.gz -RUN tar -xzf /tmp/sql-tap.tar.gz -C /usr/local/bin sql-tapd && rm /tmp/sql-tap.tar.gz -ENTRYPOINT ["sql-tapd", "--driver=mysql", "--listen=:3307", "--upstream=localhost:3306", "--grpc=:9091"] +```bash +# PostgreSQL +docker run --rm --network=host sql-tap \ + --driver=postgres --listen=:5433 --upstream=localhost:5432 --grpc=:9091 + +# MySQL +docker run --rm --network=host sql-tap \ + --driver=mysql --listen=:3307 --upstream=localhost:3306 --grpc=:9091 ``` ## Quick start diff --git a/go.mod b/go.mod index f11a6ab..edebbe0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/mickamy/sql-tap -go 1.25.0 +go 1.26.1 require ( github.com/alecthomas/chroma/v2 v2.23.1 diff --git a/proxy/mysql/proxy_test.go b/proxy/mysql/proxy_test.go index 89af258..fdee56d 100644 --- a/proxy/mysql/proxy_test.go +++ b/proxy/mysql/proxy_test.go @@ -64,7 +64,7 @@ func startProxy(t *testing.T, upstream string) (*mproxy.Proxy, string) { _ = lis.Close() p := mproxy.New(addr, upstream) - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(t.Context()) //nolint:gosec // cancel is deferred below via t.Cleanup go func() { if err := p.ListenAndServe(ctx); err != nil { diff --git a/proxy/postgres/conn.go b/proxy/postgres/conn.go index 0f7deb4..05f23ef 100644 --- a/proxy/postgres/conn.go +++ b/proxy/postgres/conn.go @@ -42,12 +42,22 @@ type conn struct { events chan<- proxy.Event // Extended query state. + // preparedStmts is only accessed by the client→upstream goroutine. preparedStmts map[string]string // stmt name -> query preparedStmtOIDs map[string][]uint32 // stmt name -> parameter OIDs lastParse string // query from most recent Parse lastParamOIDs []uint32 // parameter OIDs from most recent Parse lastBindArgs []string // args from most recent Bind lastBindStmt string // stmt name from most recent Bind + // pendingDescribes is a FIFO queue of statement names from Describe('S') + // messages. ParameterDescription responses arrive in the same order, so + // we pop from the front to match each response to its request. + pendingDescribes []string + + // stmtMu protects OID-related fields that are written by + // handleParameterDescription (upstream→client goroutine) and read by + // handleBind / written by handleParse (client→upstream goroutine). + stmtMu sync.Mutex // Transaction tracking. activeTxID string @@ -275,6 +285,8 @@ func (c *conn) captureClientMsg(msg pgproto.FrontendMessage) { c.handleSimpleQuery(m) case *pgproto.Parse: c.handleParse(m) + case *pgproto.Describe: + c.handleDescribe(m) case *pgproto.Bind: c.handleBind(m) case *pgproto.Execute: @@ -284,10 +296,14 @@ func (c *conn) captureClientMsg(msg pgproto.FrontendMessage) { func (c *conn) captureUpstreamMsg(msg pgproto.BackendMessage) { switch m := msg.(type) { + case *pgproto.ParameterDescription: + c.handleParameterDescription(m) case *pgproto.CommandComplete: c.handleCommandComplete(m) case *pgproto.ErrorResponse: c.handleErrorResponse(m) + case *pgproto.ReadyForQuery: + c.drainPendingDescribes() } } @@ -309,21 +325,69 @@ func (c *conn) handleSimpleQuery(m *pgproto.Query) { func (c *conn) handleParse(m *pgproto.Parse) { c.lastParse = m.Query + c.stmtMu.Lock() c.lastParamOIDs = m.ParameterOIDs if m.Name != "" { - c.preparedStmts[m.Name] = m.Query c.preparedStmtOIDs[m.Name] = m.ParameterOIDs } + c.stmtMu.Unlock() + if m.Name != "" { + c.preparedStmts[m.Name] = m.Query + } +} + +func (c *conn) handleDescribe(m *pgproto.Describe) { + if m.ObjectType == 'S' { + c.stmtMu.Lock() + c.pendingDescribes = append(c.pendingDescribes, m.Name) + c.stmtMu.Unlock() + } +} + +// handleParameterDescription captures the server-resolved parameter OIDs +// returned by the upstream in response to a Describe(Statement) message. +// These OIDs are authoritative — they override the OIDs from Parse, which +// are often all zeros (meaning "let the server decide"). +// Responses arrive in the same order as the corresponding Describe requests, +// so we pop from the front of pendingDescribes to match them. +func (c *conn) handleParameterDescription(m *pgproto.ParameterDescription) { + c.stmtMu.Lock() + defer c.stmtMu.Unlock() + + if len(c.pendingDescribes) == 0 { + return + } + name := c.pendingDescribes[0] + c.pendingDescribes = c.pendingDescribes[1:] + + if name == "" { + // Unnamed statement: update the fallback OIDs used by unnamed binds. + c.lastParamOIDs = m.ParameterOIDs + } else { + // Named statement: only update its entry without touching lastParamOIDs. + c.preparedStmtOIDs[name] = m.ParameterOIDs + } +} + +// drainPendingDescribes clears any unmatched Describe entries from the queue. +// Called on ReadyForQuery, which marks the end of a query cycle — any pending +// entries at this point were skipped by the server due to an earlier error. +func (c *conn) drainPendingDescribes() { + c.stmtMu.Lock() + c.pendingDescribes = nil + c.stmtMu.Unlock() } func (c *conn) handleBind(m *pgproto.Bind) { c.lastBindStmt = m.PreparedStatement + c.stmtMu.Lock() paramOIDs := c.lastParamOIDs if m.PreparedStatement != "" { if oids, ok := c.preparedStmtOIDs[m.PreparedStatement]; ok { paramOIDs = oids } } + c.stmtMu.Unlock() c.lastBindArgs = make([]string, len(m.Parameters)) for i, p := range m.Parameters { oid := uint32(0) @@ -363,7 +427,7 @@ func decodeBinaryParam(p []byte, oid uint32) string { switch len(p) { case 1: // bool or int8 - return strconv.Itoa(int(int8(p[0]))) + return strconv.Itoa(int(int8(p[0]))) //nolint:gosec // interpreting as signed int8 case 2: return strconv.Itoa(int(int16(binary.BigEndian.Uint16(p)))) //nolint:gosec // interpreting as signed int16 case 4: diff --git a/proxy/postgres/conn_test.go b/proxy/postgres/conn_test.go index a804e9c..477643d 100644 --- a/proxy/postgres/conn_test.go +++ b/proxy/postgres/conn_test.go @@ -1,12 +1,179 @@ package postgres_test import ( + "encoding/binary" "testing" "time" pgproxy "github.com/mickamy/sql-tap/proxy/postgres" ) +// encodeMicros encodes microseconds as big-endian 8 bytes (PostgreSQL binary timestamp format). +func encodeMicros(us int64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(us)) //nolint:gosec // test helper: intentional signed→unsigned reinterpretation + return b +} + +func TestParameterDescriptionFlow(t *testing.T) { + t.Parallel() + + t.Run("unnamed bind picks up ParameterDescription OIDs", func(t *testing.T) { + t.Parallel() + + tc := pgproxy.NewTestConn() + + // Parse with OID=0 (server-inferred), then Describe unnamed statement. + tc.HandleParse("", "SELECT id FROM t WHERE ts < $1", []uint32{0}) + tc.HandleDescribe("") + + // Server responds with the actual OID (timestamptz). + tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestampTZ}) + + // Bind with binary timestamp — should decode as RFC3339 thanks to resolved OID. + tc.HandleBind("", [][]byte{encodeMicros(826159500119733)}, []int16{1}) + + args := tc.LastBindArgs() + if len(args) != 1 { + t.Fatalf("got %d args, want 1", len(args)) + } + if _, err := time.Parse(time.RFC3339Nano, args[0]); err != nil { + t.Errorf("arg = %q, want RFC3339 parseable string", args[0]) + } + }) + + t.Run("named statement uses per-statement OIDs", func(t *testing.T) { + t.Parallel() + + tc := pgproxy.NewTestConn() + + // Parse named statement with OID=0. + tc.HandleParse("s1", "SELECT id FROM t WHERE ts < $1", []uint32{0}) + tc.HandleDescribe("s1") + + // Server responds with actual OID. + tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestamp}) + + // Bind referencing the named statement. + tc.HandleBind("s1", [][]byte{encodeMicros(826159500119733)}, []int16{1}) + + args := tc.LastBindArgs() + if len(args) != 1 { + t.Fatalf("got %d args, want 1", len(args)) + } + if _, err := time.Parse(time.RFC3339Nano, args[0]); err != nil { + t.Errorf("arg = %q, want RFC3339 parseable string", args[0]) + } + }) + + t.Run("named statement OIDs do not pollute unnamed bind", func(t *testing.T) { + t.Parallel() + + tc := pgproxy.NewTestConn() + + // Parse + Describe unnamed statement with OID=0 (no ParameterDescription yet). + tc.HandleParse("", "SELECT id FROM t WHERE id = $1", []uint32{0}) + + // Parse + Describe named statement — server resolves as timestamp. + tc.HandleParse("s1", "SELECT id FROM t WHERE ts < $1", []uint32{0}) + tc.HandleDescribe("s1") + tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestamp}) + + // Bind unnamed — should still use OID=0 (not timestamp from s1). + tc.HandleBind("", [][]byte{encodeMicros(42)}, []int16{1}) + + args := tc.LastBindArgs() + if len(args) != 1 { + t.Fatalf("got %d args, want 1", len(args)) + } + // OID=0 with 8-byte binary → plain int64, not RFC3339. + if args[0] != "42" { + t.Errorf("arg = %q, want %q (should not be treated as timestamp)", args[0], "42") + } + }) + + t.Run("ReadyForQuery drains stale pending describes", func(t *testing.T) { + t.Parallel() + + tc := pgproxy.NewTestConn() + + // Parse fails on server, Describe is skipped — no ParameterDescription arrives. + tc.HandleParse("", "INVALID SQL", []uint32{0}) + tc.HandleDescribe("") + + // ReadyForQuery clears the stale entry. + tc.HandleReadyForQuery() + + // Next cycle: new Parse + Describe for a real query. + tc.HandleParse("", "SELECT id FROM t WHERE ts < $1", []uint32{0}) + tc.HandleDescribe("") + tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestampTZ}) + + tc.HandleBind("", [][]byte{encodeMicros(826159500119733)}, []int16{1}) + + args := tc.LastBindArgs() + if len(args) != 1 { + t.Fatalf("got %d args, want 1", len(args)) + } + if _, err := time.Parse(time.RFC3339Nano, args[0]); err != nil { + t.Errorf("arg = %q, want RFC3339 parseable string", args[0]) + } + }) +} + +func TestDecodeBinaryParam(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + oid uint32 + wantTime bool // true if the result should parse as RFC3339 + wantString string + }{ + { + name: "timestamp OID decodes as RFC3339", + data: encodeMicros(826159500119733), + oid: pgproxy.OIDTimestamp, + wantTime: true, + }, + { + name: "timestamptz OID decodes as RFC3339", + data: encodeMicros(826159500119733), + oid: pgproxy.OIDTimestampTZ, + wantTime: true, + }, + { + name: "zero OID 8-byte value decoded as plain int64", + data: encodeMicros(826159500119733), + oid: 0, + wantTime: false, + wantString: "826159500119733", + }, + { + name: "4-byte int32", + data: []byte{0, 0, 0, 42}, + oid: 0, + wantString: "42", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := pgproxy.DecodeBinaryParam(tt.data, tt.oid) + if tt.wantTime { + if _, err := time.Parse(time.RFC3339Nano, got); err != nil { + t.Errorf("DecodeBinaryParam() = %q, want RFC3339 parseable string", got) + } + } else if tt.wantString != "" && got != tt.wantString { + t.Errorf("DecodeBinaryParam() = %q, want %q", got, tt.wantString) + } + }) + } +} + func TestDecodePGTimestampMicros(t *testing.T) { t.Parallel() diff --git a/proxy/postgres/export_test.go b/proxy/postgres/export_test.go index 9bb3305..c8d65ff 100644 --- a/proxy/postgres/export_test.go +++ b/proxy/postgres/export_test.go @@ -1,5 +1,60 @@ package postgres +import ( + pgproto "github.com/jackc/pgproto3/v2" + + "github.com/mickamy/sql-tap/proxy" +) + // Exported wrappers for internal symbols used in package-external tests. var DecodePGTimestampMicros = decodePGTimestampMicros + +// DecodeBinaryParam exposes decodeBinaryParam for testing. +var DecodeBinaryParam = decodeBinaryParam + +// OID constants for testing. +const ( + OIDTimestamp = oidTimestamp + OIDTimestampTZ = oidTimestampTZ +) + +// TestConn wraps conn for protocol-level unit tests. +type TestConn struct{ c *conn } + +// NewTestConn creates a minimal conn for testing the extended query flow. +func NewTestConn() *TestConn { + return &TestConn{c: &conn{ + preparedStmts: make(map[string]string), + preparedStmtOIDs: make(map[string][]uint32), + events: make(chan<- proxy.Event, 16), + }} +} + +func (tc *TestConn) HandleParse(name, query string, oids []uint32) { + tc.c.handleParse(&pgproto.Parse{Name: name, Query: query, ParameterOIDs: oids}) +} + +func (tc *TestConn) HandleDescribe(name string) { + tc.c.handleDescribe(&pgproto.Describe{ObjectType: 'S', Name: name}) +} + +func (tc *TestConn) HandleParameterDescription(oids []uint32) { + tc.c.handleParameterDescription(&pgproto.ParameterDescription{ParameterOIDs: oids}) +} + +func (tc *TestConn) HandleBind(stmtName string, params [][]byte, formatCodes []int16) { + tc.c.handleBind(&pgproto.Bind{ + PreparedStatement: stmtName, + Parameters: params, + ParameterFormatCodes: formatCodes, + }) +} + +func (tc *TestConn) HandleReadyForQuery() { + tc.c.drainPendingDescribes() +} + +func (tc *TestConn) LastBindArgs() []string { + return tc.c.lastBindArgs +} diff --git a/proxy/postgres/proxy_test.go b/proxy/postgres/proxy_test.go index 97292aa..33af736 100644 --- a/proxy/postgres/proxy_test.go +++ b/proxy/postgres/proxy_test.go @@ -70,7 +70,7 @@ func startProxy(t *testing.T, upstream string) (*pproxy.Proxy, string) { _ = lis.Close() p := pproxy.New(addr, upstream) - ctx, cancel := context.WithCancel(t.Context()) + ctx, cancel := context.WithCancel(t.Context()) //nolint:gosec // cancel is deferred below via t.Cleanup go func() { if err := p.ListenAndServe(ctx); err != nil { diff --git a/tui/editor.go b/tui/editor.go index 22d84ea..6903629 100644 --- a/tui/editor.go +++ b/tui/editor.go @@ -37,7 +37,7 @@ func openEditor(query string, args []string, mode explain.Mode) tea.Cmd { if _, err := f.WriteString(header + query); err != nil { _ = f.Close() - _ = os.Remove(path) //nolint:gosec // path is a controlled temp file created by this function + _ = os.Remove(path) return func() tea.Msg { return editorResultMsg{err: err, mode: mode} } diff --git a/web/web_test.go b/web/web_test.go index abdc101..a855950 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -24,7 +24,7 @@ func TestStaticFiles(t *testing.T) { ctx := context.Background() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/", nil) - resp, err := http.DefaultClient.Do(req) //nolint:gosec // test URL + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } @@ -51,7 +51,7 @@ func TestSSE_ReceivesEvents(t *testing.T) { defer cancel() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/api/events", nil) - resp, err := http.DefaultClient.Do(req) //nolint:gosec // test URL + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } @@ -116,7 +116,7 @@ func TestSSE_DisconnectUnsubscribes(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+"/api/events", nil) - resp, err := http.DefaultClient.Do(req) //nolint:gosec // test URL + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } @@ -147,7 +147,7 @@ func TestExplain_NotConfigured(t *testing.T) { ctx := context.Background() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, ts.URL+"/api/explain", body) req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) //nolint:gosec // test URL + resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) }