Skip to content

Commit

Permalink
Fix prepared statement caching bug.
Browse files Browse the repository at this point in the history
It wasn't caching the prepared statement by keyspace, so if you performed
identical queries on two different keyspaces, it would use the prepared
statement from the wrong keyspace (and end up querying against the wrong
keyspace).

I added a "currentKeyspace" field to the Conn struct so the statement caching
code has it available to use as part of the cache key.
  • Loading branch information
Muir Manders committed Aug 13, 2014
1 parent b8692cd commit 763d85a
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 36 deletions.
108 changes: 83 additions & 25 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,45 @@ func createTable(s *Session, table string) error {
return err
}

func createSession(tb testing.TB) *Session {
func createCluster() *ClusterConfig {
cluster := NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = 5 * time.Second
cluster.Consistency = Quorum
cluster.RetryPolicy.NumRetries = *flagRetry

return cluster
}

func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
session, err := cluster.CreateSession()
if err != nil {
tb.Fatal("createSession:", err)
}
if err = session.Query(`DROP KEYSPACE ` + keyspace).Exec(); err != nil {
tb.Log("drop keyspace:", err)
}
if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE %s
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, keyspace, *flagRF)).Consistency(All).Exec(); err != nil {
tb.Fatalf("error creating keyspace %s: %v", keyspace, err)
}
tb.Logf("Created keyspace %s", keyspace)
session.Close()
}

func createSession(tb testing.TB) *Session {
cluster := createCluster()

// Drop and re-create the keyspace once. Different tests should use their own
// individual tables, but can assume that the table does not exist before.
initOnce.Do(func() {
session, err := cluster.CreateSession()
if err != nil {
tb.Fatal("createSession:", err)
}
// Drop and re-create the keyspace once. Different tests should use their own
// individual tables, but can assume that the table does not exist before.
if err := session.Query(`DROP KEYSPACE gocql_test`).Exec(); err != nil {
tb.Log("drop keyspace:", err)
}
if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE gocql_test
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, *flagRF)).Consistency(All).Exec(); err != nil {
tb.Fatal("create keyspace:", err)
}
tb.Log("Created keyspace")
session.Close()
createKeyspace(tb, cluster, "gocql_test")
})

cluster.Keyspace = "gocql_test"
session, err := cluster.CreateSession()
if err != nil {
Expand Down Expand Up @@ -941,19 +952,19 @@ func TestPreparedCacheEviction(t *testing.T) {
//Walk through all the configured hosts and test cache retention and eviction
var selFound, insFound, updFound, delFound, selEvict bool
for i := range session.cfg.Hosts {
_, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 1")
_, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1")
selFound = selFound || ok

_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042INSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
insFound = insFound || ok

_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042UPDATE prepcachetest SET mod = ? WHERE id = ?")
_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?")
updFound = updFound || ok

_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042DELETE FROM prepcachetest WHERE id = ?")
_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?")
delFound = delFound || ok

_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 0")
_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0")
selEvict = selEvict || !ok

}
Expand All @@ -974,6 +985,53 @@ func TestPreparedCacheEviction(t *testing.T) {
}
}

func TestPreparedCacheKey(t *testing.T) {
session := createSession(t)
defer session.Close()

// create a second keyspace
cluster2 := createCluster()
createKeyspace(t, cluster2, "gocql_test2")
cluster2.Keyspace = "gocql_test2"
session2, err := cluster2.CreateSession()
if err != nil {
t.Fatal("create session:", err)
}
defer session2.Close()

// both keyspaces have a table named "test_stmt_cache_key"
if err := createTable(session, "CREATE TABLE test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil {
t.Fatal("create table:", err)
}
if err := createTable(session2, "CREATE TABLE test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil {
t.Fatal("create table:", err)
}

// both tables have a single row with the same partition key but different column value
if err = session.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "one").Exec(); err != nil {
t.Fatal("insert:", err)
}
if err = session2.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "two").Exec(); err != nil {
t.Fatal("insert:", err)
}

// should be able to see different values in each keyspace
var value string
if err = session.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil {
t.Fatal("select:", err)
}
if value != "one" {
t.Errorf("Expected one, got %s", value)
}

if err = session2.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil {
t.Fatal("select:", err)
}
if value != "two" {
t.Errorf("Expected two, got %s", value)
}
}

//TestMarshalFloat64Ptr tests to see that a pointer to a float64 is marshalled correctly.
func TestMarshalFloat64Ptr(t *testing.T) {
session := createSession(t)
Expand Down Expand Up @@ -1051,7 +1109,7 @@ func TestVarint(t *testing.T) {

err := session.Query("SELECT test FROM varint_test").Scan(&result64)
if err == nil || strings.Index(err.Error(), "out of range") == -1 {
t.Errorf("expected our of range error since value is too big for int64")
t.Errorf("expected out of range error since value is too big for int64")
}

// value not set in cassandra, leave bind variable empty
Expand Down
30 changes: 19 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ type Conn struct {
calls []callReq
nwait int32

pool ConnectionPool
compressor Compressor
auth Authenticator
addr string
version uint8
pool ConnectionPool
compressor Compressor
auth Authenticator
addr string
version uint8
currentKeyspace string

closedMu sync.RWMutex
isClosed bool
Expand Down Expand Up @@ -310,7 +311,10 @@ func (c *Conn) ping() error {

func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
stmtsLRU.mu.Lock()
if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {

stmtCacheKey := c.addr + c.currentKeyspace + stmt

if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
flight := val.(*inflightPrepare)
stmtsLRU.mu.Unlock()
flight.wg.Wait()
Expand All @@ -319,7 +323,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {

flight := new(inflightPrepare)
flight.wg.Add(1)
stmtsLRU.lru.Add(c.addr+stmt, flight)
stmtsLRU.lru.Add(stmtCacheKey, flight)
stmtsLRU.mu.Unlock()

resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
Expand All @@ -345,7 +349,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {

if err != nil {
stmtsLRU.mu.Lock()
stmtsLRU.lru.Remove(c.addr + stmt)
stmtsLRU.lru.Remove(stmtCacheKey)
stmtsLRU.mu.Unlock()
}

Expand Down Expand Up @@ -414,8 +418,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
return &Iter{}
case RequestErrUnprepared:
stmtsLRU.mu.Lock()
if _, ok := stmtsLRU.lru.Get(c.addr + qry.stmt); ok {
stmtsLRU.lru.Remove(c.addr + qry.stmt)
stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
stmtsLRU.lru.Remove(stmtCacheKey)
stmtsLRU.mu.Unlock()
return c.executeQuery(qry)
}
Expand Down Expand Up @@ -470,6 +475,9 @@ func (c *Conn) UseKeyspace(keyspace string) error {
default:
return NewErrProtocol("Unknown type in response to USE: %s", x)
}

c.currentKeyspace = keyspace

return nil
}

Expand Down Expand Up @@ -537,7 +545,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
stmt, found := stmts[string(x.StatementId)]
if found {
stmtsLRU.mu.Lock()
stmtsLRU.lru.Remove(c.addr + stmt)
stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
stmtsLRU.mu.Unlock()
}
if found {
Expand Down

0 comments on commit 763d85a

Please sign in to comment.