From 763d85a2a198d0fd5f0b1d5d1461d757e351a20e Mon Sep 17 00:00:00 2001 From: Muir Manders Date: Mon, 11 Aug 2014 16:36:45 -0700 Subject: [PATCH] Fix prepared statement caching bug. It wasn't caching the prepared statement by keyspace, so if you performed identical queries on two different keyspaces, it would use the prepared statement from the wrong keyspace (and end up querying against the wrong keyspace). I added a "currentKeyspace" field to the Conn struct so the statement caching code has it available to use as part of the cache key. --- cassandra_test.go | 108 +++++++++++++++++++++++++++++++++++----------- conn.go | 30 ++++++++----- 2 files changed, 102 insertions(+), 36 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 2ffa79bf9..34ac62501 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -51,7 +51,7 @@ func createTable(s *Session, table string) error { return err } -func createSession(tb testing.TB) *Session { +func createCluster() *ClusterConfig { cluster := NewCluster(clusterHosts...) cluster.ProtoVersion = *flagProto cluster.CQLVersion = *flagCQL @@ -59,26 +59,37 @@ func createSession(tb testing.TB) *Session { cluster.Consistency = Quorum cluster.RetryPolicy.NumRetries = *flagRetry + return cluster +} + +func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { + session, err := cluster.CreateSession() + if err != nil { + tb.Fatal("createSession:", err) + } + if err = session.Query(`DROP KEYSPACE ` + keyspace).Exec(); err != nil { + tb.Log("drop keyspace:", err) + } + if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : %d + }`, keyspace, *flagRF)).Consistency(All).Exec(); err != nil { + tb.Fatalf("error creating keyspace %s: %v", keyspace, err) + } + tb.Logf("Created keyspace %s", keyspace) + session.Close() +} + +func createSession(tb testing.TB) *Session { + cluster := createCluster() + + // Drop and re-create the keyspace once. Different tests should use their own + // individual tables, but can assume that the table does not exist before. initOnce.Do(func() { - session, err := cluster.CreateSession() - if err != nil { - tb.Fatal("createSession:", err) - } - // Drop and re-create the keyspace once. Different tests should use their own - // individual tables, but can assume that the table does not exist before. - if err := session.Query(`DROP KEYSPACE gocql_test`).Exec(); err != nil { - tb.Log("drop keyspace:", err) - } - if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE gocql_test - WITH replication = { - 'class' : 'SimpleStrategy', - 'replication_factor' : %d - }`, *flagRF)).Consistency(All).Exec(); err != nil { - tb.Fatal("create keyspace:", err) - } - tb.Log("Created keyspace") - session.Close() + createKeyspace(tb, cluster, "gocql_test") }) + cluster.Keyspace = "gocql_test" session, err := cluster.CreateSession() if err != nil { @@ -941,19 +952,19 @@ func TestPreparedCacheEviction(t *testing.T) { //Walk through all the configured hosts and test cache retention and eviction var selFound, insFound, updFound, delFound, selEvict bool for i := range session.cfg.Hosts { - _, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 1") + _, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1") selFound = selFound || ok - _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042INSERT INTO prepcachetest (id,mod) VALUES (?, ?)") + _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)") insFound = insFound || ok - _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042UPDATE prepcachetest SET mod = ? WHERE id = ?") + _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?") updFound = updFound || ok - _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042DELETE FROM prepcachetest WHERE id = ?") + _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?") delFound = delFound || ok - _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 0") + _, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0") selEvict = selEvict || !ok } @@ -974,6 +985,53 @@ func TestPreparedCacheEviction(t *testing.T) { } } +func TestPreparedCacheKey(t *testing.T) { + session := createSession(t) + defer session.Close() + + // create a second keyspace + cluster2 := createCluster() + createKeyspace(t, cluster2, "gocql_test2") + cluster2.Keyspace = "gocql_test2" + session2, err := cluster2.CreateSession() + if err != nil { + t.Fatal("create session:", err) + } + defer session2.Close() + + // both keyspaces have a table named "test_stmt_cache_key" + if err := createTable(session, "CREATE TABLE test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil { + t.Fatal("create table:", err) + } + if err := createTable(session2, "CREATE TABLE test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil { + t.Fatal("create table:", err) + } + + // both tables have a single row with the same partition key but different column value + if err = session.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "one").Exec(); err != nil { + t.Fatal("insert:", err) + } + if err = session2.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "two").Exec(); err != nil { + t.Fatal("insert:", err) + } + + // should be able to see different values in each keyspace + var value string + if err = session.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil { + t.Fatal("select:", err) + } + if value != "one" { + t.Errorf("Expected one, got %s", value) + } + + if err = session2.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil { + t.Fatal("select:", err) + } + if value != "two" { + t.Errorf("Expected two, got %s", value) + } +} + //TestMarshalFloat64Ptr tests to see that a pointer to a float64 is marshalled correctly. func TestMarshalFloat64Ptr(t *testing.T) { session := createSession(t) @@ -1051,7 +1109,7 @@ func TestVarint(t *testing.T) { err := session.Query("SELECT test FROM varint_test").Scan(&result64) if err == nil || strings.Index(err.Error(), "out of range") == -1 { - t.Errorf("expected our of range error since value is too big for int64") + t.Errorf("expected out of range error since value is too big for int64") } // value not set in cassandra, leave bind variable empty diff --git a/conn.go b/conn.go index 29a51c124..038d2dc9d 100644 --- a/conn.go +++ b/conn.go @@ -66,11 +66,12 @@ type Conn struct { calls []callReq nwait int32 - pool ConnectionPool - compressor Compressor - auth Authenticator - addr string - version uint8 + pool ConnectionPool + compressor Compressor + auth Authenticator + addr string + version uint8 + currentKeyspace string closedMu sync.RWMutex isClosed bool @@ -310,7 +311,10 @@ func (c *Conn) ping() error { func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) { stmtsLRU.mu.Lock() - if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok { + + stmtCacheKey := c.addr + c.currentKeyspace + stmt + + if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok { flight := val.(*inflightPrepare) stmtsLRU.mu.Unlock() flight.wg.Wait() @@ -319,7 +323,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) { flight := new(inflightPrepare) flight.wg.Add(1) - stmtsLRU.lru.Add(c.addr+stmt, flight) + stmtsLRU.lru.Add(stmtCacheKey, flight) stmtsLRU.mu.Unlock() resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace) @@ -345,7 +349,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) { if err != nil { stmtsLRU.mu.Lock() - stmtsLRU.lru.Remove(c.addr + stmt) + stmtsLRU.lru.Remove(stmtCacheKey) stmtsLRU.mu.Unlock() } @@ -414,8 +418,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter { return &Iter{} case RequestErrUnprepared: stmtsLRU.mu.Lock() - if _, ok := stmtsLRU.lru.Get(c.addr + qry.stmt); ok { - stmtsLRU.lru.Remove(c.addr + qry.stmt) + stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt + if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok { + stmtsLRU.lru.Remove(stmtCacheKey) stmtsLRU.mu.Unlock() return c.executeQuery(qry) } @@ -470,6 +475,9 @@ func (c *Conn) UseKeyspace(keyspace string) error { default: return NewErrProtocol("Unknown type in response to USE: %s", x) } + + c.currentKeyspace = keyspace + return nil } @@ -537,7 +545,7 @@ func (c *Conn) executeBatch(batch *Batch) error { stmt, found := stmts[string(x.StatementId)] if found { stmtsLRU.mu.Lock() - stmtsLRU.lru.Remove(c.addr + stmt) + stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt) stmtsLRU.mu.Unlock() } if found {