GoToSocial/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/driver.go
2023-05-09 18:19:48 +01:00

461 lines
12 KiB
Go

package otelsql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"go.opentelemetry.io/otel/trace"
)
// Open is a wrapper over sql.Open that instruments the sql.DB to record executed queries
// using OpenTelemetry API.
func Open(driverName, dsn string, opts ...Option) (*sql.DB, error) {
db, err := sql.Open(driverName, dsn)
if err != nil {
return nil, err
}
return patchDB(db, dsn, opts...)
}
func patchDB(db *sql.DB, dsn string, opts ...Option) (*sql.DB, error) {
dbDriver := db.Driver()
d := newDriver(dbDriver, opts)
if _, ok := dbDriver.(driver.DriverContext); ok {
connector, err := d.OpenConnector(dsn)
if err != nil {
return nil, err
}
return sqlOpenDB(connector, d.instrum), nil
}
return sqlOpenDB(&dsnConnector{
driver: d,
dsn: dsn,
}, d.instrum), nil
}
// OpenDB is a wrapper over sql.OpenDB that instruments the sql.DB to record executed queries
// using OpenTelemetry API.
func OpenDB(connector driver.Connector, opts ...Option) *sql.DB {
instrum := newDBInstrum(opts)
c := newConnector(connector.Driver(), connector, instrum)
return sqlOpenDB(c, instrum)
}
func sqlOpenDB(connector driver.Connector, instrum *dbInstrum) *sql.DB {
db := sql.OpenDB(connector)
ReportDBStatsMetrics(db, WithMeterProvider(instrum.meterProvider), WithAttributes(instrum.attrs...))
return db
}
type dsnConnector struct {
driver *otelDriver
dsn string
}
func (c *dsnConnector) Connect(ctx context.Context) (driver.Conn, error) {
var conn driver.Conn
err := c.driver.instrum.withSpan(ctx, "db.Connect", "",
func(ctx context.Context, span trace.Span) error {
var err error
conn, err = c.driver.Open(c.dsn)
return err
})
return conn, err
}
func (c *dsnConnector) Driver() driver.Driver {
return c.driver
}
//------------------------------------------------------------------------------
type otelDriver struct {
driver driver.Driver
driverCtx driver.DriverContext
instrum *dbInstrum
}
var _ driver.DriverContext = (*otelDriver)(nil)
func newDriver(dr driver.Driver, opts []Option) *otelDriver {
driverCtx, _ := dr.(driver.DriverContext)
d := &otelDriver{
driver: dr,
driverCtx: driverCtx,
instrum: newDBInstrum(opts),
}
return d
}
func (d *otelDriver) Open(name string) (driver.Conn, error) {
conn, err := d.driver.Open(name)
if err != nil {
return nil, err
}
return newConn(conn, d.instrum), nil
}
func (d *otelDriver) OpenConnector(dsn string) (driver.Connector, error) {
connector, err := d.driverCtx.OpenConnector(dsn)
if err != nil {
return nil, err
}
return newConnector(d, connector, d.instrum), nil
}
//------------------------------------------------------------------------------
type connector struct {
driver.Connector
driver driver.Driver
instrum *dbInstrum
}
var _ driver.Connector = (*connector)(nil)
func newConnector(d driver.Driver, c driver.Connector, instrum *dbInstrum) *connector {
return &connector{
driver: d,
Connector: c,
instrum: instrum,
}
}
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var conn driver.Conn
if err := c.instrum.withSpan(ctx, "db.Connect", "",
func(ctx context.Context, span trace.Span) error {
var err error
conn, err = c.Connector.Connect(ctx)
return err
}); err != nil {
return nil, err
}
return newConn(conn, c.instrum), nil
}
func (c *connector) Driver() driver.Driver {
return c.driver
}
//------------------------------------------------------------------------------
type otelConn struct {
driver.Conn
instrum *dbInstrum
ping pingFunc
exec execFunc
execCtx execCtxFunc
query queryFunc
queryCtx queryCtxFunc
prepareCtx prepareCtxFunc
beginTx beginTxFunc
resetSession resetSessionFunc
checkNamedValue checkNamedValueFunc
}
var _ driver.Conn = (*otelConn)(nil)
func newConn(conn driver.Conn, instrum *dbInstrum) *otelConn {
cn := &otelConn{
Conn: conn,
instrum: instrum,
}
cn.ping = cn.createPingFunc(conn)
cn.exec = cn.createExecFunc(conn)
cn.execCtx = cn.createExecCtxFunc(conn)
cn.query = cn.createQueryFunc(conn)
cn.queryCtx = cn.createQueryCtxFunc(conn)
cn.prepareCtx = cn.createPrepareCtxFunc(conn)
cn.beginTx = cn.createBeginTxFunc(conn)
cn.resetSession = cn.createResetSessionFunc(conn)
cn.checkNamedValue = cn.createCheckNamedValueFunc(conn)
return cn
}
var _ driver.Pinger = (*otelConn)(nil)
func (c *otelConn) Ping(ctx context.Context) error {
return c.ping(ctx)
}
type pingFunc func(ctx context.Context) error
func (c *otelConn) createPingFunc(conn driver.Conn) pingFunc {
if pinger, ok := conn.(driver.Pinger); ok {
return func(ctx context.Context) error {
return c.instrum.withSpan(ctx, "db.Ping", "",
func(ctx context.Context, span trace.Span) error {
return pinger.Ping(ctx)
})
}
}
return func(ctx context.Context) error {
return driver.ErrSkip
}
}
//------------------------------------------------------------------------------
var _ driver.Execer = (*otelConn)(nil)
func (c *otelConn) Exec(query string, args []driver.Value) (driver.Result, error) {
return c.exec(query, args)
}
type execFunc func(query string, args []driver.Value) (driver.Result, error)
func (c *otelConn) createExecFunc(conn driver.Conn) execFunc {
if execer, ok := conn.(driver.Execer); ok {
return func(query string, args []driver.Value) (driver.Result, error) {
return execer.Exec(query, args)
}
}
return func(query string, args []driver.Value) (driver.Result, error) {
return nil, driver.ErrSkip
}
}
//------------------------------------------------------------------------------
var _ driver.ExecerContext = (*otelConn)(nil)
func (c *otelConn) ExecContext(
ctx context.Context, query string, args []driver.NamedValue,
) (driver.Result, error) {
return c.execCtx(ctx, query, args)
}
type execCtxFunc func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error)
func (c *otelConn) createExecCtxFunc(conn driver.Conn) execCtxFunc {
var fn execCtxFunc
if execer, ok := conn.(driver.ExecerContext); ok {
fn = execer.ExecContext
} else {
fn = func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
vArgs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.exec(query, vArgs)
}
}
return func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var res driver.Result
if err := c.instrum.withSpan(ctx, "db.Exec", query,
func(ctx context.Context, span trace.Span) error {
var err error
res, err = fn(ctx, query, args)
if err != nil {
return err
}
if span.IsRecording() {
rows, err := res.RowsAffected()
if err == nil {
span.SetAttributes(dbRowsAffected.Int64(rows))
}
}
return nil
}); err != nil {
return nil, err
}
return res, nil
}
}
//------------------------------------------------------------------------------
var _ driver.Queryer = (*otelConn)(nil)
func (c *otelConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return c.query(query, args)
}
type queryFunc func(query string, args []driver.Value) (driver.Rows, error)
func (c *otelConn) createQueryFunc(conn driver.Conn) queryFunc {
if queryer, ok := c.Conn.(driver.Queryer); ok {
return func(query string, args []driver.Value) (driver.Rows, error) {
return queryer.Query(query, args)
}
}
return func(query string, args []driver.Value) (driver.Rows, error) {
return nil, driver.ErrSkip
}
}
//------------------------------------------------------------------------------
var _ driver.QueryerContext = (*otelConn)(nil)
func (c *otelConn) QueryContext(
ctx context.Context, query string, args []driver.NamedValue,
) (driver.Rows, error) {
return c.queryCtx(ctx, query, args)
}
type queryCtxFunc func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error)
func (c *otelConn) createQueryCtxFunc(conn driver.Conn) queryCtxFunc {
var fn queryCtxFunc
if queryer, ok := c.Conn.(driver.QueryerContext); ok {
fn = queryer.QueryContext
} else {
fn = func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
vArgs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.query(query, vArgs)
}
}
return func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var rows driver.Rows
err := c.instrum.withSpan(ctx, "db.Query", query,
func(ctx context.Context, span trace.Span) error {
var err error
rows, err = fn(ctx, query, args)
return err
})
return rows, err
}
}
//------------------------------------------------------------------------------
var _ driver.ConnPrepareContext = (*otelConn)(nil)
func (c *otelConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepareCtx(ctx, query)
}
type prepareCtxFunc func(ctx context.Context, query string) (driver.Stmt, error)
func (c *otelConn) createPrepareCtxFunc(conn driver.Conn) prepareCtxFunc {
var fn prepareCtxFunc
if preparer, ok := c.Conn.(driver.ConnPrepareContext); ok {
fn = preparer.PrepareContext
} else {
fn = func(ctx context.Context, query string) (driver.Stmt, error) {
return c.Conn.Prepare(query)
}
}
return func(ctx context.Context, query string) (driver.Stmt, error) {
var stmt driver.Stmt
if err := c.instrum.withSpan(ctx, "db.Prepare", query,
func(ctx context.Context, span trace.Span) error {
var err error
stmt, err = fn(ctx, query)
return err
}); err != nil {
return nil, err
}
return newStmt(stmt, query, c.instrum), nil
}
}
var _ driver.ConnBeginTx = (*otelConn)(nil)
func (c *otelConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return c.beginTx(ctx, opts)
}
type beginTxFunc func(ctx context.Context, opts driver.TxOptions) (driver.Tx, error)
func (c *otelConn) createBeginTxFunc(conn driver.Conn) beginTxFunc {
var fn beginTxFunc
if txor, ok := conn.(driver.ConnBeginTx); ok {
fn = txor.BeginTx
} else {
fn = func(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return conn.Begin()
}
}
return func(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
var tx driver.Tx
if err := c.instrum.withSpan(ctx, "db.Begin", "",
func(ctx context.Context, span trace.Span) error {
var err error
tx, err = fn(ctx, opts)
return err
}); err != nil {
return nil, err
}
return newTx(ctx, tx, c.instrum), nil
}
}
//------------------------------------------------------------------------------
var _ driver.SessionResetter = (*otelConn)(nil)
func (c *otelConn) ResetSession(ctx context.Context) error {
return c.resetSession(ctx)
}
type resetSessionFunc func(ctx context.Context) error
func (c *otelConn) createResetSessionFunc(conn driver.Conn) resetSessionFunc {
if resetter, ok := c.Conn.(driver.SessionResetter); ok {
return func(ctx context.Context) error {
return resetter.ResetSession(ctx)
}
}
return func(ctx context.Context) error {
return driver.ErrSkip
}
}
//------------------------------------------------------------------------------
var _ driver.NamedValueChecker = (*otelConn)(nil)
func (c *otelConn) CheckNamedValue(value *driver.NamedValue) error {
return c.checkNamedValue(value)
}
type checkNamedValueFunc func(*driver.NamedValue) error
func (c *otelConn) createCheckNamedValueFunc(conn driver.Conn) checkNamedValueFunc {
if checker, ok := c.Conn.(driver.NamedValueChecker); ok {
return func(value *driver.NamedValue) error {
return checker.CheckNamedValue(value)
}
}
return func(value *driver.NamedValue) error {
return driver.ErrSkip
}
}
//------------------------------------------------------------------------------
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
args := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("otelsql: driver does not support named parameters")
}
args[n] = param.Value
}
return args, nil
}