Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
- "**"

permissions:
contents: write
contents: read

jobs:
lint:
Expand All @@ -21,22 +21,19 @@ jobs:

- uses: actions/setup-go@v5
with:
go-version: "1.25"
go-version: "1.26"

- name: golangci-lint
uses: golangci/golangci-lint-action@v9

test:
runs-on: ubuntu-latest
strategy:
matrix:
go-version: [ "1.25", "1.26" ]
steps:
- uses: actions/checkout@v6

- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
go-version: "1.26"

- name: Run tests
run: make test
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,25 @@ Then reference it in `devenv.nix`:

### Docker

**PostgreSQL**

```dockerfile
FROM postgres:18-alpine
FROM alpine:3
ARG SQL_TAP_VERSION=0.0.1
ARG TARGETARCH
ADD https://github.com/mickamy/sql-tap/releases/download/v${SQL_TAP_VERSION}/sql-tap_${SQL_TAP_VERSION}_linux_${TARGETARCH}.tar.gz /tmp/sql-tap.tar.gz
RUN tar -xzf /tmp/sql-tap.tar.gz -C /usr/local/bin sql-tapd && rm /tmp/sql-tap.tar.gz
ENTRYPOINT ["sql-tapd", "--driver=postgres", "--listen=:5433", "--upstream=localhost:5432", "--grpc=:9091"]
ENTRYPOINT ["sql-tapd"]
```

**MySQL**
Run as a sidecar alongside your database:

```dockerfile
FROM mysql:8
ARG SQL_TAP_VERSION=0.0.1
ARG TARGETARCH
ADD https://github.com/mickamy/sql-tap/releases/download/v${SQL_TAP_VERSION}/sql-tap_${SQL_TAP_VERSION}_linux_${TARGETARCH}.tar.gz /tmp/sql-tap.tar.gz
RUN tar -xzf /tmp/sql-tap.tar.gz -C /usr/local/bin sql-tapd && rm /tmp/sql-tap.tar.gz
ENTRYPOINT ["sql-tapd", "--driver=mysql", "--listen=:3307", "--upstream=localhost:3306", "--grpc=:9091"]
```bash
# PostgreSQL
docker run --rm --network=host sql-tap \
--driver=postgres --listen=:5433 --upstream=localhost:5432 --grpc=:9091

# MySQL
docker run --rm --network=host sql-tap \
--driver=mysql --listen=:3307 --upstream=localhost:3306 --grpc=:9091
```

## Quick start
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/mickamy/sql-tap

go 1.25.0
go 1.26.1
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR title is specific to PostgreSQL ParameterDescription handling, but this change also bumps the module Go version (and separately modifies CI/docs). This makes the PR harder to review and can complicate rollback. Consider splitting the Go/tooling/docs changes into separate PR(s) unless they are required for the protocol fix.

Copilot uses AI. Check for mistakes.

require (
github.com/alecthomas/chroma/v2 v2.23.1
Expand Down
2 changes: 1 addition & 1 deletion proxy/mysql/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func startProxy(t *testing.T, upstream string) (*mproxy.Proxy, string) {
_ = lis.Close()

p := mproxy.New(addr, upstream)
ctx, cancel := context.WithCancel(t.Context())
ctx, cancel := context.WithCancel(t.Context()) //nolint:gosec // cancel is deferred below via t.Cleanup
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inline suppression uses //nolint:gosec, but the warning for context.WithCancel without an immediate cancel() is typically from govet's lostcancel analyzer, not gosec. As written, this comment likely won’t suppress the CI finding; use the correct nolint tag (e.g. lostcancel/govet) or adjust the code in a way the analyzer recognizes.

Suggested change
ctx, cancel := context.WithCancel(t.Context()) //nolint:gosec // cancel is deferred below via t.Cleanup
ctx, cancel := context.WithCancel(t.Context()) //nolint:govet // cancel is deferred below via t.Cleanup

Copilot uses AI. Check for mistakes.

go func() {
if err := p.ListenAndServe(ctx); err != nil {
Expand Down
68 changes: 66 additions & 2 deletions proxy/postgres/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,22 @@ type conn struct {
events chan<- proxy.Event

// Extended query state.
// preparedStmts is only accessed by the client→upstream goroutine.
preparedStmts map[string]string // stmt name -> query
preparedStmtOIDs map[string][]uint32 // stmt name -> parameter OIDs
lastParse string // query from most recent Parse
lastParamOIDs []uint32 // parameter OIDs from most recent Parse
lastBindArgs []string // args from most recent Bind
lastBindStmt string // stmt name from most recent Bind
// pendingDescribes is a FIFO queue of statement names from Describe('S')
// messages. ParameterDescription responses arrive in the same order, so
// we pop from the front to match each response to its request.
pendingDescribes []string

// stmtMu protects OID-related fields that are written by
// handleParameterDescription (upstream→client goroutine) and read by
// handleBind / written by handleParse (client→upstream goroutine).
stmtMu sync.Mutex

// Transaction tracking.
activeTxID string
Expand Down Expand Up @@ -275,6 +285,8 @@ func (c *conn) captureClientMsg(msg pgproto.FrontendMessage) {
c.handleSimpleQuery(m)
case *pgproto.Parse:
c.handleParse(m)
case *pgproto.Describe:
c.handleDescribe(m)
case *pgproto.Bind:
c.handleBind(m)
case *pgproto.Execute:
Expand All @@ -284,10 +296,14 @@ func (c *conn) captureClientMsg(msg pgproto.FrontendMessage) {

func (c *conn) captureUpstreamMsg(msg pgproto.BackendMessage) {
switch m := msg.(type) {
case *pgproto.ParameterDescription:
c.handleParameterDescription(m)
case *pgproto.CommandComplete:
c.handleCommandComplete(m)
case *pgproto.ErrorResponse:
c.handleErrorResponse(m)
case *pgproto.ReadyForQuery:
c.drainPendingDescribes()
}
}

Expand All @@ -309,21 +325,69 @@ func (c *conn) handleSimpleQuery(m *pgproto.Query) {

func (c *conn) handleParse(m *pgproto.Parse) {
c.lastParse = m.Query
c.stmtMu.Lock()
c.lastParamOIDs = m.ParameterOIDs
if m.Name != "" {
c.preparedStmts[m.Name] = m.Query
c.preparedStmtOIDs[m.Name] = m.ParameterOIDs
}
c.stmtMu.Unlock()
if m.Name != "" {
c.preparedStmts[m.Name] = m.Query
}
}

func (c *conn) handleDescribe(m *pgproto.Describe) {
if m.ObjectType == 'S' {
c.stmtMu.Lock()
c.pendingDescribes = append(c.pendingDescribes, m.Name)
c.stmtMu.Unlock()
Comment on lines +339 to +343
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lastDescribeStmt stores only a single statement name, but ParameterDescription responses arrive asynchronously from the upstream goroutine and PostgreSQL allows pipelining multiple Describe('S') messages. If multiple Describe(Statement) messages are in-flight, lastDescribeStmt can be overwritten before the first ParameterDescription arrives, causing OIDs to be associated with the wrong prepared statement. Consider tracking a FIFO queue of pending described statement names (push in handleDescribe, pop in handleParameterDescription) so responses are matched to the correct statement.

Copilot uses AI. Check for mistakes.
}
}

// handleParameterDescription captures the server-resolved parameter OIDs
// returned by the upstream in response to a Describe(Statement) message.
// These OIDs are authoritative — they override the OIDs from Parse, which
// are often all zeros (meaning "let the server decide").
// Responses arrive in the same order as the corresponding Describe requests,
// so we pop from the front of pendingDescribes to match them.
func (c *conn) handleParameterDescription(m *pgproto.ParameterDescription) {
c.stmtMu.Lock()
defer c.stmtMu.Unlock()

if len(c.pendingDescribes) == 0 {
return
}
name := c.pendingDescribes[0]
c.pendingDescribes = c.pendingDescribes[1:]

if name == "" {
// Unnamed statement: update the fallback OIDs used by unnamed binds.
c.lastParamOIDs = m.ParameterOIDs
} else {
// Named statement: only update its entry without touching lastParamOIDs.
c.preparedStmtOIDs[name] = m.ParameterOIDs
}
Comment on lines +339 to +369
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pendingDescribes is only advanced when a ParameterDescription arrives. If the upstream returns an ErrorResponse for a Describe('S') (e.g., unknown statement), there will be no ParameterDescription, leaving a stale entry in the FIFO and causing the next ParameterDescription to be associated with the wrong statement. Consider clearing (or popping) pending describe entries on relevant ErrorResponse/sync boundaries (e.g., in handleErrorResponse under stmtMu) to keep the queue in sync and avoid mis-decoding subsequent bind params.

Copilot uses AI. Check for mistakes.
Comment on lines +347 to +369
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behavior captures server-resolved parameter OIDs via Describe/ParameterDescription, but there are no tests asserting that (a) unnamed binds pick up ParameterDescription OIDs and (b) named prepared statements override per-statement OIDs. Adding a small protocol-level/unit test around handleDescribe + handleParameterDescription + handleBind would prevent regressions where timestamps fall back to OID=0 and decode incorrectly.

Copilot uses AI. Check for mistakes.
}

// drainPendingDescribes clears any unmatched Describe entries from the queue.
// Called on ReadyForQuery, which marks the end of a query cycle — any pending
// entries at this point were skipped by the server due to an earlier error.
func (c *conn) drainPendingDescribes() {
c.stmtMu.Lock()
c.pendingDescribes = nil
c.stmtMu.Unlock()
}

func (c *conn) handleBind(m *pgproto.Bind) {
c.lastBindStmt = m.PreparedStatement
c.stmtMu.Lock()
paramOIDs := c.lastParamOIDs
if m.PreparedStatement != "" {
if oids, ok := c.preparedStmtOIDs[m.PreparedStatement]; ok {
paramOIDs = oids
}
}
c.stmtMu.Unlock()
c.lastBindArgs = make([]string, len(m.Parameters))
for i, p := range m.Parameters {
oid := uint32(0)
Expand Down Expand Up @@ -363,7 +427,7 @@ func decodeBinaryParam(p []byte, oid uint32) string {
switch len(p) {
case 1:
// bool or int8
return strconv.Itoa(int(int8(p[0])))
return strconv.Itoa(int(int8(p[0]))) //nolint:gosec // interpreting as signed int8
case 2:
return strconv.Itoa(int(int16(binary.BigEndian.Uint16(p)))) //nolint:gosec // interpreting as signed int16
case 4:
Expand Down
167 changes: 167 additions & 0 deletions proxy/postgres/conn_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,179 @@
package postgres_test

import (
"encoding/binary"
"testing"
"time"

pgproxy "github.com/mickamy/sql-tap/proxy/postgres"
)

// encodeMicros encodes microseconds as big-endian 8 bytes (PostgreSQL binary timestamp format).
func encodeMicros(us int64) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(us)) //nolint:gosec // test helper: intentional signed→unsigned reinterpretation
return b
}

func TestParameterDescriptionFlow(t *testing.T) {
t.Parallel()

t.Run("unnamed bind picks up ParameterDescription OIDs", func(t *testing.T) {
t.Parallel()

tc := pgproxy.NewTestConn()

// Parse with OID=0 (server-inferred), then Describe unnamed statement.
tc.HandleParse("", "SELECT id FROM t WHERE ts < $1", []uint32{0})
tc.HandleDescribe("")

// Server responds with the actual OID (timestamptz).
tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestampTZ})

// Bind with binary timestamp — should decode as RFC3339 thanks to resolved OID.
tc.HandleBind("", [][]byte{encodeMicros(826159500119733)}, []int16{1})

args := tc.LastBindArgs()
if len(args) != 1 {
t.Fatalf("got %d args, want 1", len(args))
}
if _, err := time.Parse(time.RFC3339Nano, args[0]); err != nil {
t.Errorf("arg = %q, want RFC3339 parseable string", args[0])
}
})

t.Run("named statement uses per-statement OIDs", func(t *testing.T) {
t.Parallel()

tc := pgproxy.NewTestConn()

// Parse named statement with OID=0.
tc.HandleParse("s1", "SELECT id FROM t WHERE ts < $1", []uint32{0})
tc.HandleDescribe("s1")

// Server responds with actual OID.
tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestamp})

// Bind referencing the named statement.
tc.HandleBind("s1", [][]byte{encodeMicros(826159500119733)}, []int16{1})

args := tc.LastBindArgs()
if len(args) != 1 {
t.Fatalf("got %d args, want 1", len(args))
}
if _, err := time.Parse(time.RFC3339Nano, args[0]); err != nil {
t.Errorf("arg = %q, want RFC3339 parseable string", args[0])
}
})

t.Run("named statement OIDs do not pollute unnamed bind", func(t *testing.T) {
t.Parallel()

tc := pgproxy.NewTestConn()

// Parse + Describe unnamed statement with OID=0 (no ParameterDescription yet).
tc.HandleParse("", "SELECT id FROM t WHERE id = $1", []uint32{0})

// Parse + Describe named statement — server resolves as timestamp.
tc.HandleParse("s1", "SELECT id FROM t WHERE ts < $1", []uint32{0})
tc.HandleDescribe("s1")
tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestamp})

// Bind unnamed — should still use OID=0 (not timestamp from s1).
tc.HandleBind("", [][]byte{encodeMicros(42)}, []int16{1})

args := tc.LastBindArgs()
if len(args) != 1 {
t.Fatalf("got %d args, want 1", len(args))
}
// OID=0 with 8-byte binary → plain int64, not RFC3339.
if args[0] != "42" {
t.Errorf("arg = %q, want %q (should not be treated as timestamp)", args[0], "42")
}
})

t.Run("ReadyForQuery drains stale pending describes", func(t *testing.T) {
t.Parallel()

tc := pgproxy.NewTestConn()

// Parse fails on server, Describe is skipped — no ParameterDescription arrives.
tc.HandleParse("", "INVALID SQL", []uint32{0})
tc.HandleDescribe("")

// ReadyForQuery clears the stale entry.
tc.HandleReadyForQuery()

// Next cycle: new Parse + Describe for a real query.
tc.HandleParse("", "SELECT id FROM t WHERE ts < $1", []uint32{0})
tc.HandleDescribe("")
tc.HandleParameterDescription([]uint32{pgproxy.OIDTimestampTZ})

tc.HandleBind("", [][]byte{encodeMicros(826159500119733)}, []int16{1})

args := tc.LastBindArgs()
if len(args) != 1 {
t.Fatalf("got %d args, want 1", len(args))
}
if _, err := time.Parse(time.RFC3339Nano, args[0]); err != nil {
t.Errorf("arg = %q, want RFC3339 parseable string", args[0])
}
})
}

func TestDecodeBinaryParam(t *testing.T) {
t.Parallel()

tests := []struct {
name string
data []byte
oid uint32
wantTime bool // true if the result should parse as RFC3339
wantString string
}{
{
name: "timestamp OID decodes as RFC3339",
data: encodeMicros(826159500119733),
oid: pgproxy.OIDTimestamp,
wantTime: true,
},
{
name: "timestamptz OID decodes as RFC3339",
data: encodeMicros(826159500119733),
oid: pgproxy.OIDTimestampTZ,
wantTime: true,
},
{
name: "zero OID 8-byte value decoded as plain int64",
data: encodeMicros(826159500119733),
oid: 0,
wantTime: false,
wantString: "826159500119733",
},
{
name: "4-byte int32",
data: []byte{0, 0, 0, 42},
oid: 0,
wantString: "42",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got := pgproxy.DecodeBinaryParam(tt.data, tt.oid)
if tt.wantTime {
if _, err := time.Parse(time.RFC3339Nano, got); err != nil {
t.Errorf("DecodeBinaryParam() = %q, want RFC3339 parseable string", got)
}
} else if tt.wantString != "" && got != tt.wantString {
t.Errorf("DecodeBinaryParam() = %q, want %q", got, tt.wantString)
}
})
}
}

func TestDecodePGTimestampMicros(t *testing.T) {
t.Parallel()

Expand Down
Loading