Skip to content

Commit

Permalink
Merge pull request #77 from stefantds/st/support-execer-context
Browse files Browse the repository at this point in the history
Support context interfaces for Exec and Query
  • Loading branch information
shogo82148 authored Oct 16, 2021
2 parents 8ec7bf3 + 0580e9d commit 2677a07
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,14 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
panic("not supported")
}

// ExecContext calls the original Exec method of the connection.
// ExecContext calls the original ExecContext (or Exec as a fallback) method of the connection.
// It will trigger PreExec, Exec, PostExec hooks.
//
// If the original connection does not satisfy "database/sql/driver".Execer, it return ErrSkip error.
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
execer, ok := conn.Conn.(driver.Execer)
if !ok {
execer, exOk := conn.Conn.(driver.Execer)
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
if !exOk && !exCtxOk {
return nil, driver.ErrSkip
}

Expand All @@ -217,7 +218,7 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam
}

// call the original method.
if execerCtx, ok := execer.(driver.ExecerContext); ok {
if execerCtx != nil {
result, err = execerCtx.ExecContext(c, stmt.QueryString, args)
} else {
select {
Expand Down Expand Up @@ -256,10 +257,11 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
// QueryContext executes a query that may return rows.
// It wil trigger PreQuery, Query, PostQuery hooks.
//
// If the original connection does not satisfy "database/sql/driver".Queryer, it return ErrSkip error.
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
queryer, ok := conn.Conn.(driver.Queryer)
if !ok {
queryer, qok := conn.Conn.(driver.Queryer)
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
if !qok && !qCtxOk {
return nil, driver.ErrSkip
}

Expand All @@ -280,7 +282,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
}

// call the original method.
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
if queryerCtx != nil {
rows, err = queryerCtx.QueryContext(c, stmt.QueryString, args)
} else {
select {
Expand Down

0 comments on commit 2677a07

Please sign in to comment.