GoToSocial/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go
2023-05-09 18:19:48 +01:00

121 lines
2.8 KiB
Go

package otelsql
import (
"context"
"database/sql/driver"
"go.opentelemetry.io/otel/trace"
)
// otelStmt wraps a driver.Stmt so that its executions are traced through a
// dbInstrum. Methods that are not overridden here are promoted unchanged from
// the embedded driver.Stmt.
type otelStmt struct {
	// Stmt is the wrapped driver statement.
	driver.Stmt

	// query is the SQL text handed to instrum.withSpan on every call.
	query string

	// instrum performs the span bookkeeping for this statement.
	instrum *dbInstrum

	// execCtx is the ExecContext implementation, resolved once up front.
	execCtx stmtExecCtxFunc

	// queryCtx is the QueryContext implementation, resolved once up front.
	queryCtx stmtQueryCtxFunc
}

// Compile-time check that otelStmt still satisfies driver.Stmt.
var (
	_ driver.Stmt = (*otelStmt)(nil)
)
// newStmt wraps stmt, prepared from query, so that its Exec and Query calls
// are traced via instrum. The ExecContext/QueryContext implementations are
// chosen once here (depending on what stmt supports) rather than on every call.
func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt {
	wrapped := &otelStmt{Stmt: stmt, query: query, instrum: instrum}
	wrapped.execCtx = wrapped.createExecCtxFunc(stmt)
	wrapped.queryCtx = wrapped.createQueryCtxFunc(stmt)
	return wrapped
}
//------------------------------------------------------------------------------
var _ driver.StmtExecContext = (*otelStmt)(nil)

// ExecContext implements driver.StmtExecContext by delegating to the
// function stored in s.execCtx, which records a "stmt.Exec" span around the
// underlying execution.
func (s *otelStmt) ExecContext(
	ctx context.Context, args []driver.NamedValue,
) (driver.Result, error) {
	return s.execCtx(ctx, args)
}
// stmtExecCtxFunc has the signature of driver.StmtExecContext.ExecContext.
type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error)

// createExecCtxFunc builds the ExecContext implementation for s.
//
// If stmt implements driver.StmtExecContext, its ExecContext is used directly.
// Otherwise the call falls back to the legacy driver.Stmt.Exec after converting
// the named arguments with namedValueToValue. In both cases the call runs inside
// s.instrum.withSpan under the name "stmt.Exec". When the span is recording, the
// affected-row count is attached on a best-effort basis.
func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc {
	var fn stmtExecCtxFunc
	if execer, ok := stmt.(driver.StmtExecContext); ok {
		fn = execer.ExecContext
	} else {
		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
			vArgs, err := namedValueToValue(args)
			if err != nil {
				return nil, err
			}
			// The legacy Exec cannot observe ctx. Because otelStmt always
			// advertises StmtExecContext, database/sql will not perform its
			// usual pre-call cancellation check for legacy drivers, so do it
			// here to preserve that behaviour.
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			return stmt.Exec(vArgs)
		}
	}

	return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
		var res driver.Result
		err := s.instrum.withSpan(ctx, "stmt.Exec", s.query,
			func(ctx context.Context, span trace.Span) error {
				var err error
				res, err = fn(ctx, args)
				if err != nil {
					return err
				}
				// Best effort: not every driver supports RowsAffected, and a
				// failure (or a nil Result from a misbehaving driver) must not
				// fail or crash the Exec itself.
				if res != nil && span.IsRecording() {
					if rows, raErr := res.RowsAffected(); raErr == nil {
						span.SetAttributes(dbRowsAffected.Int64(rows))
					}
				}
				return nil
			})
		return res, err
	}
}
//------------------------------------------------------------------------------
var _ driver.StmtQueryContext = (*otelStmt)(nil)

// QueryContext implements driver.StmtQueryContext by delegating to the
// function stored in s.queryCtx, which records a "stmt.Query" span around the
// underlying query.
func (s *otelStmt) QueryContext(
	ctx context.Context, args []driver.NamedValue,
) (driver.Rows, error) {
	return s.queryCtx(ctx, args)
}
// stmtQueryCtxFunc has the signature of driver.StmtQueryContext.QueryContext.
type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)

// createQueryCtxFunc builds the QueryContext implementation for s.
//
// If stmt implements driver.StmtQueryContext, its QueryContext is used directly.
// Otherwise the call falls back to the legacy driver.Stmt.Query after converting
// the named arguments with namedValueToValue. In both cases the call runs inside
// s.instrum.withSpan under the name "stmt.Query".
func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc {
	var fn stmtQueryCtxFunc
	if queryer, ok := stmt.(driver.StmtQueryContext); ok {
		fn = queryer.QueryContext
	} else {
		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
			vArgs, err := namedValueToValue(args)
			if err != nil {
				return nil, err
			}
			// The legacy Query cannot observe ctx. Because otelStmt always
			// advertises StmtQueryContext, database/sql will not perform its
			// usual pre-call cancellation check for legacy drivers, so do it
			// here to preserve that behaviour.
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			// Call the wrapped statement explicitly (as the exec path does)
			// rather than relying on method promotion through s.
			return stmt.Query(vArgs)
		}
	}

	return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
		var rows driver.Rows
		err := s.instrum.withSpan(ctx, "stmt.Query", s.query,
			func(ctx context.Context, span trace.Span) error {
				var err error
				rows, err = fn(ctx, args)
				return err
			})
		return rows, err
	}
}