Skip to content

Commit 2083882

Browse files
committed
implement a basic audit log analyzer
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
1 parent d9ef65a commit 2083882

File tree

4 files changed

+519
-13
lines changed

4 files changed

+519
-13
lines changed

cmd/auditloganalyzer/main.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Copyright 2025 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package main
5+
6+
import (
7+
"database/sql"
8+
"encoding/csv"
9+
"fmt"
10+
"os"
11+
"time"
12+
13+
"github.com/pingcap/tiproxy/lib/config"
14+
"github.com/pingcap/tiproxy/lib/util/cmd"
15+
lg "github.com/pingcap/tiproxy/lib/util/logger"
16+
replaycmd "github.com/pingcap/tiproxy/pkg/sqlreplay/cmd"
17+
"github.com/pingcap/tiproxy/pkg/sqlreplay/replay"
18+
"github.com/pingcap/tiproxy/pkg/util/versioninfo"
19+
"github.com/spf13/cobra"
20+
"go.uber.org/zap"
21+
)
22+
23+
// Supported values of the --output-format flag.
const (
24+
// formatCSV selects writing the analysis result to a CSV file.
formatCSV = "csv"
25+
// formatMySQL selects writing the analysis result into a table through a
// "mysql" database/sql connection.
formatMySQL = "mysql"
26+
)
27+
28+
func main() {
29+
rootCmd := &cobra.Command{
30+
Use: os.Args[0],
31+
Short: "start the analyzer",
32+
Version: fmt.Sprintf("%s, commit %s", versioninfo.TiProxyVersion, versioninfo.TiProxyGitHash),
33+
}
34+
rootCmd.SetOut(os.Stdout)
35+
rootCmd.SetErr(os.Stderr)
36+
37+
input := rootCmd.PersistentFlags().String("input", "", "directory for traffic files")
38+
startTime := rootCmd.PersistentFlags().Time("start-time", time.Time{}, []string{time.RFC3339, time.RFC3339Nano}, "the start time to analyze the audit log.")
39+
endTime := rootCmd.PersistentFlags().Time("end-time", time.Time{}, []string{time.RFC3339, time.RFC3339Nano}, "the end time to analyze the audit log.")
40+
output := rootCmd.PersistentFlags().String("output", "audit_log_analysis_result.csv", "the output path for analysis result.")
41+
db := rootCmd.PersistentFlags().String("db", "", "the target database to analyze. Empty means all databases will be recorded.")
42+
filterCommandWithRetry := rootCmd.PersistentFlags().Bool("filter-command-with-retry", false, "filter out commands that are retries according to the audit log.")
43+
outputFormat := rootCmd.PersistentFlags().String("output-format", "csv", "the output format for analysis result. Currently only 'csv' and 'mysql' is supported.")
44+
outputTableName := rootCmd.PersistentFlags().String("output-table-name", "audit_log_analysis", "the output table name when output format is 'mysql'.")
45+
46+
rootCmd.RunE = func(cmd *cobra.Command, _ []string) error {
47+
logger, _, _, err := lg.BuildLogger(&config.Log{
48+
Encoder: "tidb",
49+
LogOnline: config.LogOnline{
50+
Level: "info",
51+
},
52+
})
53+
if err != nil {
54+
return err
55+
}
56+
57+
result, err := replay.Analyze(logger, replaycmd.AnalyzeConfig{
58+
Input: *input,
59+
Start: *startTime,
60+
End: *endTime,
61+
DB: *db,
62+
FilterCommandWithRetry: *filterCommandWithRetry,
63+
})
64+
if err != nil {
65+
return err
66+
}
67+
68+
switch *outputFormat {
69+
case formatCSV:
70+
logger.Info("writing analysis result to CSV", zap.String("output", *output))
71+
return writeAnalyzeResultToCSV(result, *output)
72+
case formatMySQL:
73+
logger.Info("writing analysis result to MySQL", zap.String("output", *output), zap.String("table", *outputTableName))
74+
return writeAnalyzeResultToMySQL(result, *output, *outputTableName)
75+
default:
76+
return fmt.Errorf("unsupported output format: %s", *outputFormat)
77+
}
78+
}
79+
80+
cmd.RunRootCommand(rootCmd)
81+
}
82+
83+
func writeAnalyzeResultToCSV(result replaycmd.AuditLogAnalyzeResult, outputPath string) error {
84+
f, err := os.Create(outputPath)
85+
if err != nil {
86+
return err
87+
}
88+
defer f.Close()
89+
w := csv.NewWriter(f)
90+
for sql, group := range result {
91+
dbAccessPatterns, err := group.DBAccessPatterns.MarshalJSON()
92+
if err != nil {
93+
return err
94+
}
95+
record := []string{
96+
sql,
97+
fmt.Sprintf("%d", group.ExecutionCount),
98+
fmt.Sprintf("%d", group.TotalCostTime.Microseconds()),
99+
fmt.Sprintf("%d", group.TotalAffectedRows),
100+
group.StmtTypes.String(),
101+
string(dbAccessPatterns),
102+
}
103+
if err := w.Write(record); err != nil {
104+
return err
105+
}
106+
}
107+
w.Flush()
108+
if err := w.Error(); err != nil {
109+
return err
110+
}
111+
return nil
112+
}
113+
114+
func writeAnalyzeResultToMySQL(result replaycmd.AuditLogAnalyzeResult, outputPath string, outputTableName string) error {
115+
db, err := sql.Open("mysql", outputPath)
116+
if err != nil {
117+
return err
118+
}
119+
defer db.Close()
120+
121+
createTableSQL := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
122+
sql_text TEXT,
123+
execution_count INT,
124+
total_cost_time BIGINT,
125+
total_affected_rows BIGINT,
126+
statement_types TEXT
127+
db_access_patterns JSON
128+
);`, outputTableName)
129+
_, err = db.Exec(createTableSQL)
130+
if err != nil {
131+
return err
132+
}
133+
134+
insertSQL := fmt.Sprintf(`INSERT INTO %s (sql_text, execution_count, total_cost_time, total_affected_rows, statement_types, db_access_patterns) VALUES (?, ?, ?, ?, ?, ?)`, outputTableName)
135+
stmt, err := db.Prepare(insertSQL)
136+
if err != nil {
137+
return err
138+
}
139+
defer stmt.Close()
140+
141+
for sqlText, group := range result {
142+
dbAccessPatterns, err := group.DBAccessPatterns.MarshalJSON()
143+
if err != nil {
144+
return err
145+
}
146+
_, err = stmt.Exec(sqlText, group.ExecutionCount, group.TotalCostTime.Microseconds(), group.TotalAffectedRows, group.StmtTypes.String(), string(dbAccessPatterns))
147+
if err != nil {
148+
return err
149+
}
150+
}
151+
152+
return nil
153+
}
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
// Copyright 2025 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package cmd
5+
6+
import (
7+
"encoding/json"
8+
"sort"
9+
"strconv"
10+
"strings"
11+
"time"
12+
13+
"github.com/pingcap/tidb/pkg/parser"
14+
"github.com/pingcap/tiproxy/lib/util/errors"
15+
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
16+
"github.com/siddontang/go/hack"
17+
)
18+
19+
// stmtTypesSet is a set of statement type names.
type stmtTypesSet map[string]struct{}

// String returns the statement types joined by commas in ascending order.
// Sorting keeps the output stable across calls, because Go map iteration
// order is randomized. An empty or nil set yields an empty string.
func (s stmtTypesSet) String() string {
	types := make([]string, 0, len(s))
	for stmtType := range s {
		types = append(types, stmtType)
	}
	sort.Strings(types)
	return strings.Join(types, ",")
}
28+
29+
// dbAccessPatterns is a set of database access patterns. Each key is a
// comma-separated list of database names.
type dbAccessPatterns map[string]struct{}

// MarshalJSON encodes the set as a JSON array of string arrays, one inner
// array per pattern, obtained by splitting the key on commas.
//
// Patterns are emitted in ascending key order so that the output is stable
// across calls; Go map iteration order is randomized. An empty or nil set is
// encoded as JSON null.
func (d dbAccessPatterns) MarshalJSON() ([]byte, error) {
	keys := make([]string, 0, len(d))
	for patternStr := range d {
		keys = append(keys, patternStr)
	}
	sort.Strings(keys)

	// Left nil on purpose for an empty set so that it keeps encoding as null.
	var patterns [][]string
	for _, patternStr := range keys {
		patterns = append(patterns, strings.Split(patternStr, ","))
	}

	return json.Marshal(patterns)
}
40+
41+
// AuditLogGroup is the analysis result for a group of similar audit log entries.
// The analyzer keeps one group per normalized SQL text.
42+
type AuditLogGroup struct {
43+
// ExecutionCount is the number of analyzed entries that fell into the group.
ExecutionCount int
44+
// TotalCostTime is the sum of the cost time of all entries in the group.
TotalCostTime time.Duration
45+
// TotalAffectedRows is the sum of the row counts reported by the entries.
TotalAffectedRows int64
46+
// StmtTypes is the set of statement types seen in the group. It stays nil
// until an entry with a non-empty statement type is recorded.
StmtTypes stmtTypesSet
47+
// DBAccessPatterns is the set of database lists accessed by the entries.
// Each key is the sorted, comma-joined list of databases of one entry. It
// stays nil until an entry with a non-empty database list is recorded.
DBAccessPatterns dbAccessPatterns
48+
}
49+
50+
// AuditLogAnalyzeResult is the result of analyzing audit logs.
51+
type AuditLogAnalyzeResult map[string]AuditLogGroup
52+
53+
func (r AuditLogAnalyzeResult) Merge(other AuditLogAnalyzeResult) {
54+
for sql, group := range other {
55+
finalGroup := r[sql]
56+
finalGroup.ExecutionCount += group.ExecutionCount
57+
finalGroup.TotalCostTime += group.TotalCostTime
58+
finalGroup.TotalAffectedRows += group.TotalAffectedRows
59+
for stmtType := range group.StmtTypes {
60+
if finalGroup.StmtTypes == nil {
61+
finalGroup.StmtTypes = make(map[string]struct{})
62+
}
63+
finalGroup.StmtTypes[stmtType] = struct{}{}
64+
}
65+
for pattern := range group.DBAccessPatterns {
66+
if finalGroup.DBAccessPatterns == nil {
67+
finalGroup.DBAccessPatterns = make(map[string]struct{})
68+
}
69+
finalGroup.DBAccessPatterns[pattern] = struct{}{}
70+
}
71+
r[sql] = finalGroup
72+
}
73+
}
74+
75+
// AnalyzeConfig is the configuration for audit log analysis.
76+
type AnalyzeConfig struct {
77+
// Input is the location of the audit log to analyze. The analyzer command
// fills it from its --input flag.
Input string
78+
// Start is the lower bound of the analyzed time range: events that end
// before Start are skipped.
Start time.Time
79+
// End is the upper bound of the analyzed time range: the analysis stops at
// the first event that ends after End.
End time.Time
80+
// DB restricts the analysis to events whose database list contains this
// name. Empty means no restriction.
DB string
81+
// FilterCommandWithRetry selects how repeated commands are dropped. When
// true, events whose retry field is "true" are skipped. When false, events
// are checked with isDuplicatedWrite against the last command recorded on
// the same connection instead.
FilterCommandWithRetry bool
82+
}
83+
84+
// auditLogAnalyzer aggregates the statements found in an audit log.
type auditLogAnalyzer struct {
85+
// reader supplies the audit log lines together with their file name and
// line index.
reader LineReader
86+
87+
// cfg holds the filters applied while analyzing.
cfg AnalyzeConfig
88+
// connInfo keeps the per-connection context, keyed by connection ID. It
// records the last analyzed command of each connection.
connInfo map[uint64]auditLogPluginConnCtx
89+
}
90+
91+
// NewAuditLogAnalyzer creates a new audit log analyzer.
92+
func NewAuditLogAnalyzer(reader LineReader, cfg AnalyzeConfig) *auditLogAnalyzer {
93+
return &auditLogAnalyzer{
94+
reader: reader,
95+
cfg: cfg,
96+
connInfo: make(map[uint64]auditLogPluginConnCtx),
97+
}
98+
}
99+
100+
// Analyze analyzes the audit log and returns the analysis result.
101+
func (a *auditLogAnalyzer) Analyze() (AuditLogAnalyzeResult, error) {
102+
result := make(AuditLogAnalyzeResult)
103+
104+
kvs := make(map[string]string, 25)
105+
for {
106+
line, filename, lineIdx, err := a.reader.ReadLine()
107+
if err != nil {
108+
return result, err
109+
}
110+
clear(kvs)
111+
err = parseLog(kvs, hack.String(line))
112+
if err != nil {
113+
return result, errors.Errorf("%s, line %d: %s", filename, lineIdx, err.Error())
114+
}
115+
// Only analyze the COMPLETED event
116+
event, ok := kvs[auditPluginKeyEvent]
117+
if !ok || event != auditPluginEventEnd {
118+
continue
119+
}
120+
121+
// Only analyze the event within the time range
122+
startTs, endTs, err := parseStartAndEndTs(kvs)
123+
if err != nil {
124+
return nil, errors.Wrapf(err, "%s, line %d", filename, lineIdx)
125+
}
126+
if endTs.Before(a.cfg.Start) {
127+
continue
128+
}
129+
if endTs.After(a.cfg.End) {
130+
// Reach the end time, stop analyzing.
131+
return result, nil
132+
}
133+
134+
// Only analyze the `Query` and `Execute` commands
135+
cmdStr := parseCommand(kvs[auditPluginKeyCommand])
136+
if cmdStr != "Query" && cmdStr != "Execute" {
137+
continue
138+
}
139+
140+
// Only analyze the SQL in given database
141+
if len(a.cfg.DB) != 0 {
142+
databases, ok := kvs[auditPluginKeyDatabases]
143+
if !ok {
144+
continue
145+
}
146+
147+
includeTargetDB := false
148+
for _, db := range strings.Split(databases, ",") {
149+
if db == a.cfg.DB {
150+
includeTargetDB = true
151+
}
152+
}
153+
if !includeTargetDB {
154+
continue
155+
}
156+
}
157+
158+
// Try to filter out retried commands
159+
connID, err := strconv.ParseUint(kvs[auditPluginKeyConnID], 10, 64)
160+
if err != nil {
161+
return result, errors.Wrapf(err, "%s, line %d: parse conn id failed: %s", filename, lineIdx, kvs[auditPluginKeyConnID])
162+
}
163+
connInfo := a.connInfo[connID]
164+
if a.cfg.FilterCommandWithRetry {
165+
if retryStr, ok := kvs[auditPluginKeyRetry]; ok {
166+
// If it's a retry command, just skip it.
167+
if retryStr == "true" {
168+
continue
169+
}
170+
}
171+
} else {
172+
sql, err := parseSQL(kvs[auditPluginKeySQL])
173+
if err != nil {
174+
return result, errors.Wrapf(err, "%s, line %d: unquote sql failed: %s", filename, lineIdx, kvs[auditPluginKeySQL])
175+
}
176+
if isDuplicatedWrite(connInfo.lastCmd, kvs, cmdStr, sql, startTs, endTs) {
177+
continue
178+
}
179+
}
180+
181+
sql, err := parseSQL(kvs[auditPluginKeySQL])
182+
if err != nil {
183+
return result, errors.Wrapf(err, "unquote sql failed: %s", kvs[auditPluginKeySQL])
184+
}
185+
normalizedSQL := parser.Normalize(sql, "ON")
186+
group := result[normalizedSQL]
187+
188+
var costTime time.Duration
189+
costTimeStr := kvs[auditPluginKeyCostTime]
190+
if len(costTimeStr) != 0 {
191+
millis, err := strconv.ParseFloat(costTimeStr, 32)
192+
if err != nil {
193+
return result, errors.Errorf("parsing cost time failed: %s", costTimeStr)
194+
}
195+
costTime = time.Duration(millis) * (time.Millisecond)
196+
}
197+
198+
var affectedRows int64
199+
affectedRowsStr := kvs[auditPluginKeyRows]
200+
if len(affectedRowsStr) != 0 {
201+
affectedRows, err = strconv.ParseInt(affectedRowsStr, 10, 64)
202+
if err != nil {
203+
return result, errors.Errorf("parsing affected rows failed: %s", affectedRowsStr)
204+
}
205+
}
206+
207+
// Record the last command info for deduplication. We only recorded the needed fields here.
208+
connInfo.lastCmd = &Command{
209+
StartTs: startTs,
210+
EndTs: endTs,
211+
ConnID: connID,
212+
}
213+
switch cmdStr {
214+
case "Query":
215+
connInfo.lastCmd.Type = pnet.ComQuery
216+
connInfo.lastCmd.Payload = append([]byte{pnet.ComQuery.Byte()}, hack.Slice(sql)...)
217+
case "Execute":
218+
connInfo.lastCmd.Type = pnet.ComStmtExecute
219+
connInfo.lastCmd.PreparedStmt = sql
220+
}
221+
connInfo.lastCmd.StmtType = kvs[auditPluginKeyStmtType]
222+
connInfo.lastCmd.kvs = kvs
223+
a.connInfo[connID] = connInfo
224+
225+
group.ExecutionCount++
226+
group.TotalCostTime += costTime
227+
group.TotalAffectedRows += affectedRows
228+
if len(kvs[auditPluginKeyStmtType]) != 0 {
229+
if group.StmtTypes == nil {
230+
group.StmtTypes = make(map[string]struct{})
231+
}
232+
group.StmtTypes[kvs[auditPluginKeyStmtType]] = struct{}{}
233+
}
234+
if len(kvs[auditPluginKeyDatabases]) != 0 {
235+
if group.DBAccessPatterns == nil {
236+
group.DBAccessPatterns = make(map[string]struct{})
237+
}
238+
dbs := strings.Split(kvs[auditPluginKeyDatabases], ",")
239+
sort.StringSlice(dbs).Sort()
240+
group.DBAccessPatterns[strings.Join(dbs, ",")] = struct{}{}
241+
}
242+
result[normalizedSQL] = group
243+
}
244+
}

0 commit comments

Comments
 (0)