diff --git a/explain/build_args_test.go b/explain/build_args_test.go new file mode 100644 index 0000000..ffdd1c8 --- /dev/null +++ b/explain/build_args_test.go @@ -0,0 +1,236 @@ +package explain_test + +import ( + "testing" + "time" + + "github.com/mickamy/sql-tap/explain" +) + +func TestParseTimestampParams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query string + want map[int]bool + }{ + { + name: "no casts", + query: "SELECT * FROM users WHERE id = $1", + want: map[int]bool{}, + }, + { + name: "timestamp with time zone", + query: "SELECT * FROM t WHERE created_at > $1::TIMESTAMP WITH TIME ZONE", + want: map[int]bool{1: true}, + }, + { + name: "timestamptz", + query: "SELECT * FROM t WHERE ts = $2::TIMESTAMPTZ", + want: map[int]bool{2: true}, + }, + { + name: "timestamp without time zone", + query: "SELECT * FROM t WHERE ts = $3::TIMESTAMP WITHOUT TIME ZONE", + want: map[int]bool{3: true}, + }, + { + name: "plain timestamp", + query: "SELECT * FROM t WHERE ts = $1::TIMESTAMP", + want: map[int]bool{1: true}, + }, + { + name: "lowercase cast", + query: "SELECT * FROM t WHERE ts > $1::timestamp with time zone", + want: map[int]bool{1: true}, + }, + { + name: "spaces around ::", + query: "SELECT * FROM t WHERE ts = $1 :: TIMESTAMP", + want: map[int]bool{1: true}, + }, + { + name: "mixed: timestamp and non-timestamp", + query: "SELECT * FROM t WHERE key = $1::VARCHAR AND ts > $2::TIMESTAMP WITH TIME ZONE", + want: map[int]bool{2: true}, + }, + { + name: "multiple timestamp params", + query: "SELECT * FROM t WHERE a > $1::TIMESTAMP AND b < $3::TIMESTAMPTZ", + want: map[int]bool{1: true, 3: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := explain.ParseTimestampParams(tt.query) + if len(got) != len(tt.want) { + t.Fatalf("ParseTimestampParams(%q) = %v, want %v", tt.query, got, tt.want) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("ParseTimestampParams(%q)[%d] = %v, want %v", tt.query, k, got[k], v) + } + } + }) + } +} + +func TestParsePGTimestamp(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want time.Time + wantOK bool + }{ + { + name: "PostgreSQL epoch (zero)", + input: "0", + want: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + wantOK: true, + }, + { + name: "large microseconds value (issue repro: ~2026)", + input: "825505830505628", + want: func() time.Time { + microsecs := int64(825505830505628) + sec := microsecs / 1_000_000 + usec := microsecs % 1_000_000 + return time.Unix(sec+explain.PgEpochUnix, usec*1_000).UTC() + }(), + wantOK: true, + }, + { + name: "negative (before 2000-01-01)", + input: "-1000000", + want: time.Date(1999, 12, 31, 23, 59, 59, 0, time.UTC), + wantOK: true, + }, + { + name: "negative fractional (before 2000-01-01)", + input: "-1", + want: time.Date(1999, 12, 31, 23, 59, 59, 999999000, time.UTC), + wantOK: true, + }, + { + name: "non-integer string", + input: "2026-02-27T14:10:30Z", + wantOK: false, + }, + { + name: "float string", + input: "1.5", + wantOK: false, + }, + { + name: "empty string", + input: "", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, ok := explain.ParsePGTimestamp(tt.input) + if ok != tt.wantOK { + t.Fatalf("ParsePGTimestamp(%q) ok = %v, want %v", tt.input, ok, tt.wantOK) + } + if ok && !got.Equal(tt.want) { + t.Errorf("ParsePGTimestamp(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestBuildAnyArgs(t *testing.T) { + t.Parallel() + + pgEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + + tests := []struct { + name string + query string + args []string + check func(t *testing.T, got []any) + }{ + { + name: "no args", + query: "SELECT 1", + args: nil, + check: func(t *testing.T, got []any) { + t.Helper() + if len(got) != 0 { + t.Errorf("expected empty slice, got %v", got) + } + }, + }, + { + name: "non-timestamp arg stays as string", + query: "SELECT * FROM users WHERE id = $1", + args: []string{"42"}, + check: func(t *testing.T, got []any) { + t.Helper() + if s, ok := got[0].(string); !ok || s != "42" { + t.Errorf("got[0] = %v (%T), want string %q", got[0], got[0], "42") + } + }, + }, + { + name: "timestamp arg converted to time.Time", + query: "SELECT * FROM t WHERE expired_at > $1::TIMESTAMP WITH TIME ZONE", + args: []string{"825505830505628"}, + check: func(t *testing.T, got []any) { + t.Helper() + ts, ok := got[0].(time.Time) + if !ok { + t.Fatalf("got[0] = %v (%T), want time.Time", got[0], got[0]) + } + // The value should be ~2026 (pgEpoch + 825505830.5 seconds) + if ts.Before(pgEpoch) { + t.Errorf("got time %v, expected a time after 2000-01-01", ts) + } + }, + }, + { + name: "non-integer timestamp arg stays as string", + query: "SELECT * FROM t WHERE ts > $1::TIMESTAMP WITH TIME ZONE", + args: []string{"2026-02-27T14:10:30Z"}, + check: func(t *testing.T, got []any) { + t.Helper() + if s, ok := got[0].(string); !ok || s != "2026-02-27T14:10:30Z" { + t.Errorf("got[0] = %v (%T), want string %q", got[0], got[0], "2026-02-27T14:10:30Z") + } + }, + }, + { + name: "mixed args: varchar and timestamp", + query: "SELECT * FROM t WHERE key = $1::VARCHAR AND ts > $2::TIMESTAMP WITH TIME ZONE", + args: []string{"019c5c4f-f25a-772b-97d4-1646a125080d", "825505830505628"}, + check: func(t *testing.T, got []any) { + t.Helper() + if s, ok := got[0].(string); !ok || s != "019c5c4f-f25a-772b-97d4-1646a125080d" { + t.Errorf("got[0] = %v (%T), want string", got[0], got[0]) + } + if _, ok := got[1].(time.Time); !ok { + t.Errorf("got[1] = %v (%T), want time.Time", got[1], got[1]) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := explain.BuildAnyArgs(tt.query, tt.args) + tt.check(t, got) + }) + } +} diff --git a/explain/explain.go b/explain/explain.go index 2e16639..b693fb1 100644 --- a/explain/explain.go +++ b/explain/explain.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + "regexp" + "strconv" "strings" "time" ) @@ -74,10 +76,7 @@ func NewClient(db *sql.DB, driver Driver) *Client { // Run executes EXPLAIN or EXPLAIN ANALYZE for the given query with optional args. func (c *Client) Run(ctx context.Context, mode Mode, query string, args []string) (*Result, error) { - anyArgs := make([]any, len(args)) - for i, a := range args { - anyArgs[i] = a - } + anyArgs := buildAnyArgs(query, args) // MySQL/TiDB cannot parse placeholder ? without args; replace with NULL for plan-only EXPLAIN. q := query @@ -126,6 +125,67 @@ func (c *Client) Run(ctx context.Context, mode Mode, query string, args []string }, nil } +// timestampCastRe matches PostgreSQL-style timestamp cast placeholders such as +// $1::TIMESTAMP, $2::TIMESTAMPTZ, $3::TIMESTAMP WITH TIME ZONE, etc. +// The prefix "TIMESTAMP" covers all variants because TIMESTAMPTZ and +// "TIMESTAMP WITH/WITHOUT TIME ZONE" all begin with that substring. +var timestampCastRe = regexp.MustCompile(`(?i)\$(\d+)\s*::\s*TIMESTAMP`) + +// pgEpochUnix is the Unix timestamp of PostgreSQL's internal epoch (2000-01-01 00:00:00 UTC). +const pgEpochUnix int64 = 946684800 + +// buildAnyArgs converts string args to []any for use in QueryContext. +// For args whose corresponding query placeholder is cast to a timestamp type +// (e.g. $2::TIMESTAMP WITH TIME ZONE), it tries to interpret the value as a +// PostgreSQL binary-encoded timestamp (int64 microseconds since 2000-01-01 UTC) +// and converts it to time.Time. This prevents the "date/time field value out of +// range" error that occurs when a captured binary timestamp is re-used as a plain +// string in a parameterized EXPLAIN query. +func buildAnyArgs(query string, args []string) []any { + tsParams := parseTimestampParams(query) + anyArgs := make([]any, len(args)) + for i, a := range args { + if tsParams[i+1] { + if t, ok := parsePGTimestamp(a); ok { + anyArgs[i] = t + continue + } + } + anyArgs[i] = a + } + return anyArgs +} + +// parseTimestampParams returns the set of 1-indexed parameter numbers that are +// cast to a timestamp type in the query. +func parseTimestampParams(query string) map[int]bool { + m := make(map[int]bool) + for _, match := range timestampCastRe.FindAllStringSubmatch(query, -1) { + if n, err := strconv.Atoi(match[1]); err == nil { + m[n] = true + } + } + return m +} + +// parsePGTimestamp attempts to interpret s as a PostgreSQL binary-encoded +// timestamp: an int64 number of microseconds since 2000-01-01 00:00:00 UTC. +// Returns the corresponding time.Time and true on success. +func parsePGTimestamp(s string) (time.Time, bool) { + microsecs, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return time.Time{}, false + } + sec := microsecs / 1_000_000 + usec := microsecs % 1_000_000 + // Normalize: for negative microsecs, usec is negative; carry into sec. + if usec < 0 { + sec-- + usec += 1_000_000 + } + return time.Unix(sec+pgEpochUnix, usec*1_000).UTC(), true +} + // Close closes the underlying database connection. func (c *Client) Close() error { if err := c.db.Close(); err != nil { diff --git a/explain/export_test.go b/explain/export_test.go new file mode 100644 index 0000000..08945d3 --- /dev/null +++ b/explain/export_test.go @@ -0,0 +1,11 @@ +package explain + +// Exported wrappers for internal symbols used in package-external tests. + +var ( + BuildAnyArgs = buildAnyArgs + ParseTimestampParams = parseTimestampParams + ParsePGTimestamp = parsePGTimestamp +) + +const PgEpochUnix = pgEpochUnix diff --git a/proxy/postgres/conn.go b/proxy/postgres/conn.go index e30a46d..0f7deb4 100644 --- a/proxy/postgres/conn.go +++ b/proxy/postgres/conn.go @@ -23,6 +23,15 @@ type encoder interface { Encode(dst []byte) ([]byte, error) } +// Timestamp type OIDs in the PostgreSQL type catalog. +const ( + oidTimestamp uint32 = 1114 + oidTimestampTZ uint32 = 1184 +) + +// pgEpochUnix is the Unix timestamp of PostgreSQL's internal epoch (2000-01-01 00:00:00 UTC). +const pgEpochUnix int64 = 946684800 + // conn manages bidirectional relay and protocol parsing for a single connection. type conn struct { client *pgproto.Backend // reads FrontendMessages from client @@ -33,10 +42,12 @@ type conn struct { events chan<- proxy.Event // Extended query state. - preparedStmts map[string]string // stmt name -> query - lastParse string // query from most recent Parse - lastBindArgs []string // args from most recent Bind - lastBindStmt string // stmt name from most recent Bind + preparedStmts map[string]string // stmt name -> query + preparedStmtOIDs map[string][]uint32 // stmt name -> parameter OIDs + lastParse string // query from most recent Parse + lastParamOIDs []uint32 // parameter OIDs from most recent Parse + lastBindArgs []string // args from most recent Bind + lastBindStmt string // stmt name from most recent Bind // Transaction tracking. activeTxID string @@ -48,10 +59,11 @@ type conn struct { func newConn(clientConn, upstreamConn net.Conn, events chan<- proxy.Event) *conn { return &conn{ - clientConn: clientConn, - upstreamConn: upstreamConn, - events: events, - preparedStmts: make(map[string]string), + clientConn: clientConn, + upstreamConn: upstreamConn, + events: events, + preparedStmts: make(map[string]string), + preparedStmtOIDs: make(map[string][]uint32), } } @@ -297,17 +309,29 @@ func (c *conn) handleSimpleQuery(m *pgproto.Query) { func (c *conn) handleParse(m *pgproto.Parse) { c.lastParse = m.Query + c.lastParamOIDs = m.ParameterOIDs if m.Name != "" { c.preparedStmts[m.Name] = m.Query + c.preparedStmtOIDs[m.Name] = m.ParameterOIDs } } func (c *conn) handleBind(m *pgproto.Bind) { c.lastBindStmt = m.PreparedStatement + paramOIDs := c.lastParamOIDs + if m.PreparedStatement != "" { + if oids, ok := c.preparedStmtOIDs[m.PreparedStatement]; ok { + paramOIDs = oids + } + } c.lastBindArgs = make([]string, len(m.Parameters)) for i, p := range m.Parameters { + oid := uint32(0) + if i < len(paramOIDs) { + oid = paramOIDs[i] + } if isBinaryFormat(m.ParameterFormatCodes, i) { - c.lastBindArgs[i] = decodeBinaryParam(p) + c.lastBindArgs[i] = decodeBinaryParam(p, oid) } else { c.lastBindArgs[i] = string(p) } @@ -327,8 +351,15 @@ func isBinaryFormat(codes []int16, i int) bool { } // decodeBinaryParam attempts to decode a binary-format parameter into a readable string. -// Without type OID information, we use the byte length as a heuristic for common types. -func decodeBinaryParam(p []byte) string { +// When oid identifies a timestamp type, the 8-byte PostgreSQL microsecond representation +// is decoded as an RFC3339 string so it can be used directly in parameterised queries +// (including EXPLAIN) without causing "date/time field value out of range" errors. +// For all other types, the byte length is used as a heuristic for common types. +func decodeBinaryParam(p []byte, oid uint32) string { + if (oid == oidTimestamp || oid == oidTimestampTZ) && len(p) == 8 { + microsecs := int64(binary.BigEndian.Uint64(p)) //nolint:gosec // interpreting as signed int64 + return decodePGTimestampMicros(microsecs) + } switch len(p) { case 1: // bool or int8 @@ -349,6 +380,18 @@ func decodeBinaryParam(p []byte) string { return string(p) } +// decodePGTimestampMicros converts a PostgreSQL binary timestamp (microseconds since +// 2000-01-01 00:00:00 UTC) to an RFC3339Nano string that PostgreSQL can parse back. +func decodePGTimestampMicros(microsecs int64) string { + sec := microsecs / 1_000_000 + usec := microsecs % 1_000_000 + if usec < 0 { + sec-- + usec += 1_000_000 + } + return time.Unix(sec+pgEpochUnix, usec*1_000).UTC().Format(time.RFC3339Nano) +} + func (c *conn) handleExecute() { q := c.lastParse if c.lastBindStmt != "" { diff --git a/proxy/postgres/conn_test.go b/proxy/postgres/conn_test.go new file mode 100644 index 0000000..a804e9c --- /dev/null +++ b/proxy/postgres/conn_test.go @@ -0,0 +1,55 @@ +package postgres_test + +import ( + "testing" + "time" + + pgproxy "github.com/mickamy/sql-tap/proxy/postgres" +) + +func TestDecodePGTimestampMicros(t *testing.T) { + t.Parallel() + + pgEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + + tests := []struct { + name string + microsecs int64 + want time.Time + }{ + { + name: "zero (PostgreSQL epoch)", + microsecs: 0, + want: pgEpoch, + }, + { + name: "issue repro value (~2026-02-27)", + microsecs: 825505830505628, + want: pgEpoch.Add(time.Duration(825505830505628) * time.Microsecond), + }, + { + name: "negative (before 2000-01-01)", + microsecs: -1_000_000, + want: time.Date(1999, 12, 31, 23, 59, 59, 0, time.UTC), + }, + { + name: "negative fractional microsecond", + microsecs: -1, + want: time.Date(1999, 12, 31, 23, 59, 59, 999_999_000, time.UTC), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := time.Parse(time.RFC3339Nano, pgproxy.DecodePGTimestampMicros(tt.microsecs)) + if err != nil { + t.Fatalf("DecodePGTimestampMicros(%d) returned unparseable string: %v", tt.microsecs, err) + } + if !got.Equal(tt.want) { + t.Errorf("DecodePGTimestampMicros(%d) = %v, want %v", tt.microsecs, got, tt.want) + } + }) + } +} diff --git a/proxy/postgres/export_test.go b/proxy/postgres/export_test.go new file mode 100644 index 0000000..9bb3305 --- /dev/null +++ b/proxy/postgres/export_test.go @@ -0,0 +1,5 @@ +package postgres + +// Exported wrappers for internal symbols used in package-external tests. + +var DecodePGTimestampMicros = decodePGTimestampMicros