diff --git a/README.md b/README.md index 354d556f9..da6a4c52f 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,22 @@ gocql ===== -The gocql package provides a database/sql driver for CQL, the Cassandra -query language. - -This package requires a recent version of Cassandra (≥ 1.2) that supports -CQL 3.0 and the new native protocol. The native protocol is still considered -beta and must be enabled manually in Cassandra 1.2 by setting -"start_native_transport" to true in conf/cassandra.yaml. - -**Note:** gocql requires the tip version of Go, as some changes in the -`database/sql` have not made it into 1.0.x yet. There is -[a fork](https://github.com/titanous/gocql) that backports these changes -to Go 1.0.3. +Package gocql implements a fast and robust Cassandra driver for the +Go programming language. Installation ------------ go get github.com/tux21b/gocql + +Features +-------- + +* Modern Cassandra client for Cassandra 2.0 +* Built-In support for UUIDs (version 1 and 4) + + Example ------- @@ -26,48 +24,58 @@ Example package main import ( - "database/sql" "fmt" - _ "github.com/tux21b/gocql" + "github.com/tux21b/gocql" + "log" ) func main() { - db, err := sql.Open("gocql", "localhost:9042 keyspace=system") - if err != nil { - fmt.Println("Open error:", err) + // connect to your cluster + db := gocql.NewSession(gocql.Config{ + Nodes: []string{ + "192.168.1.1", + "192.168.1.2", + "192.168.1.3", + }, + Keyspace: "example", // (optional) + Consistency: gocql.ConQuorum, // (optional) + }) + defer db.Close() + + // simple query + var title, text string + if err := db.Query("SELECT title, text FROM posts WHERE title = ?", + "Lorem Ipsum").Scan(&title, &text); err != nil { + log.Fatal(err) } + fmt.Println(title, text) - rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces") - if err != nil { - fmt.Println("Query error:", err) + // iterator example + var titles []string + iter := db.Query("SELECT title FROM posts").Iter() + for iter.Scan(&title) { + titles = append(titles, title) + } + if err := iter.Close(); err != nil { + log.Fatal(err) } + fmt.Println(titles) - for rows.Next() { - var keyspace string - err = rows.Scan(&keyspace) - if err != nil { - fmt.Println("Scan error:", err) - } - fmt.Println(keyspace) + // insertion example (with custom consistency level) + if err := db.Query("INSERT INTO posts (title, text) VALUES (?, ?)", + "New Title", "foobar").Consistency(gocql.ConAny).Exec(); err != nil { + log.Fatal(err) } - if err = rows.Err(); err != nil { - fmt.Println("Iteration error:", err) - return + // prepared queries + query := gocql.NewQuery("SELECT text FROM posts WHERE title = ?") + if err := db.Do(query, "New Title").Scan(&text); err != nil { + log.Fatal(err) } + fmt.Println(text) } ``` -Please see `gocql_test.go` for some more advanced examples. - -Features --------- - -* Modern Cassandra client that is based on Cassandra's new native protocol -* Compatible with Go's `database/sql` package -* Built-In support for UUIDs (version 1 and 4) -* Optional frame compression (using snappy) - License ------- diff --git a/binary.go b/binary.go new file mode 100644 index 000000000..830511beb --- /dev/null +++ b/binary.go @@ -0,0 +1,247 @@ +package gocql + +import ( + "errors" + "net" +) + +const ( + protoRequest byte = 0x02 + protoResponse byte = 0x82 + + opError byte = 0x00 + opStartup byte = 0x01 + opReady byte = 0x02 + opAuthenticate byte = 0x03 + opOptions byte = 0x05 + opSupported byte = 0x06 + opQuery byte = 0x07 + opResult byte = 0x08 + opPrepare byte = 0x09 + opExecute byte = 0x0A + opRegister byte = 0x0B + opEvent byte = 0x0C + opBatch byte = 0x0D + opAuthChallenge byte = 0x0E + opAuthResponse byte = 0x0F + opAuthSuccess byte = 0x10 + + resultKindVoid = 1 + resultKindRows = 2 + resultKindKeyspace = 3 + resultKindPrepared = 4 + resultKindSchemaChanged = 5 + + flagQueryValues uint8 = 1 + + headerSize = 8 +) + +var ErrInvalid = errors.New("invalid response") + +type buffer []byte + +func (b *buffer) writeInt(v int32) { + p := b.grow(4) + (*b)[p] = byte(v >> 24) + (*b)[p+1] = byte(v >> 16) + (*b)[p+2] = byte(v >> 8) + (*b)[p+3] = byte(v) +} + +func (b *buffer) writeShort(v uint16) { + p := b.grow(2) + (*b)[p] = byte(v >> 8) + (*b)[p+1] = byte(v) +} + +func (b *buffer) writeString(v string) { + b.writeShort(uint16(len(v))) + p := b.grow(len(v)) + copy((*b)[p:], v) +} + +func (b *buffer) writeLongString(v string) { + b.writeInt(int32(len(v))) + p := b.grow(len(v)) + copy((*b)[p:], v) +} + +func (b *buffer) writeUUID() { +} + +func (b *buffer) writeStringList(v []string) { + b.writeShort(uint16(len(v))) + for i := range v { + b.writeString(v[i]) + } +} + +func (b *buffer) writeByte(v byte) { + p := b.grow(1) + (*b)[p] = v +} + +func (b *buffer) writeBytes(v []byte) { + if v == nil { + b.writeInt(-1) + return + } + b.writeInt(int32(len(v))) + p := b.grow(len(v)) + copy((*b)[p:], v) +} + +func (b *buffer) writeShortBytes(v []byte) { + b.writeShort(uint16(len(v))) + p := b.grow(len(v)) + copy((*b)[p:], v) +} + +func (b *buffer) writeInet(ip net.IP, port int) { + p := b.grow(1 + len(ip)) + (*b)[p] = byte(len(ip)) + copy((*b)[p+1:], ip) + b.writeInt(int32(port)) +} + +func (b *buffer) writeConsistency() { +} + +func (b *buffer) writeStringMap(v map[string]string) { + b.writeShort(uint16(len(v))) + for key, value := range v { + b.writeString(key) + b.writeString(value) + } +} + +func (b *buffer) writeStringMultimap(v map[string][]string) { + b.writeShort(uint16(len(v))) + for key, values := range v { + b.writeString(key) + b.writeStringList(values) + } +} + +func (b *buffer) setHeader(version, flags, stream, opcode uint8) { + (*b)[0] = version + (*b)[1] = flags + (*b)[2] = stream + (*b)[3] = opcode +} + +func (b *buffer) setLength(length int) { + (*b)[4] = byte(length >> 24) + (*b)[5] = byte(length >> 16) + (*b)[6] = byte(length >> 8) + (*b)[7] = byte(length) +} + +func (b *buffer) Length() int { + return int((*b)[4])<<24 | int((*b)[5])<<16 | int((*b)[6])<<8 | int((*b)[7]) +} + +func (b *buffer) grow(n int) int { + if len(*b)+n >= cap(*b) { + buf := make(buffer, len(*b), len(*b)*2+n) + copy(buf, *b) + *b = buf + } + p := len(*b) + *b = (*b)[:p+n] + return p +} + +func (b *buffer) skipHeader() { + *b = (*b)[headerSize:] +} + +func (b *buffer) readInt() int { + if len(*b) < 4 { + panic(ErrInvalid) + } + v := int((*b)[0])<<24 | int((*b)[1])<<16 | int((*b)[2])<<8 | int((*b)[3]) + *b = (*b)[4:] + return v +} + +func (b *buffer) readShort() uint16 { + if len(*b) < 2 { + panic(ErrInvalid) + } + v := uint16((*b)[0])<<8 | uint16((*b)[1]) + *b = (*b)[2:] + return v +} + +func (b *buffer) readString() string { + n := int(b.readShort()) + if len(*b) < n { + panic(ErrInvalid) + } + v := string((*b)[:n]) + *b = (*b)[n:] + return v +} + +func (b *buffer) readBytes() []byte { + n := b.readInt() + if n < 0 { + return nil + } + if len(*b) < n { + panic(ErrInvalid) + } + v := (*b)[:n] + *b = (*b)[n:] + return v +} + +func (b *buffer) readShortBytes() []byte { + n := int(b.readShort()) + if len(*b) < n { + panic(ErrInvalid) + } + v := (*b)[:n] + *b = (*b)[n:] + return v +} + +func (b *buffer) readTypeInfo() *TypeInfo { + x := b.readShort() + typ := &TypeInfo{Type: Type(x)} + switch typ.Type { + case TypeCustom: + typ.Custom = b.readString() + case TypeMap: + typ.Key = b.readTypeInfo() + fallthrough + case TypeList, TypeSet: + typ.Value = b.readTypeInfo() + } + return typ +} + +func (b *buffer) readMetaData() []columnInfo { + flags := b.readInt() + numColumns := b.readInt() + globalKeyspace := "" + globalTable := "" + if flags&1 != 0 { + globalKeyspace = b.readString() + globalTable = b.readString() + } + info := make([]columnInfo, numColumns) + for i := 0; i < numColumns; i++ { + info[i].Keyspace = globalKeyspace + info[i].Table = globalTable + if flags&1 == 0 { + info[i].Keyspace = b.readString() + info[i].Table = b.readString() + } + info[i].Name = b.readString() + info[i].TypeInfo = b.readTypeInfo() + } + return info +} diff --git a/conn.go b/conn.go new file mode 100644 index 000000000..d56ef521d --- /dev/null +++ b/conn.go @@ -0,0 +1,176 @@ +package gocql + +import ( + "io" + "net" + "sync" + "sync/atomic" +) + +type queryInfo struct { + id []byte + args []columnInfo + rval []columnInfo + avail chan bool +} + +type connection struct { + conn net.Conn + uniq chan uint8 + reply []chan buffer + waiting uint64 + + prepMu sync.Mutex + prep map[string]*queryInfo +} + +func connect(addr string, cfg *Config) (*connection, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + c := &connection{ + conn: conn, + uniq: make(chan uint8, 64), + reply: make([]chan buffer, 64), + prep: make(map[string]*queryInfo), + } + for i := 0; i < cap(c.uniq); i++ { + c.uniq <- uint8(i) + } + + go c.recv() + + frame := make(buffer, headerSize) + frame.setHeader(protoRequest, 0, 0, opStartup) + frame.writeStringMap(map[string]string{ + "CQL_VERSION": cfg.CQLVersion, + }) + frame.setLength(len(frame) - headerSize) + + frame = c.request(frame) + + if cfg.Keyspace != "" { + qry := &Query{stmt: "USE " + cfg.Keyspace} + frame, err = c.executeQuery(qry) + } + + return c, nil +} + +func (c *connection) recv() { + for { + frame := make(buffer, headerSize, headerSize+512) + if _, err := io.ReadFull(c.conn, frame); err != nil { + return + } + if frame[0] != protoResponse { + continue + } + if length := frame.Length(); length > 0 { + frame.grow(frame.Length()) + io.ReadFull(c.conn, frame[headerSize:]) + } + c.dispatch(frame) + } + panic("not possible") +} + +func (c *connection) request(frame buffer) buffer { + id := <-c.uniq + frame[2] = id + c.reply[id] = make(chan buffer, 1) + + for { + w := atomic.LoadUint64(&c.waiting) + if atomic.CompareAndSwapUint64(&c.waiting, w, w|(1<= 128 { + return + } + for { + w := atomic.LoadUint64(&c.waiting) + if w&(1< 0 { + info = c.prepareQuery(query.stmt) + } + + frame := make(buffer, headerSize, headerSize+512) + frame.setHeader(protoRequest, 0, 0, opQuery) + frame.writeLongString(query.stmt) + frame.writeShort(uint16(query.cons)) + flags := uint8(0) + if len(query.args) > 0 { + flags |= flagQueryValues + } + frame.writeByte(flags) + if len(query.args) > 0 { + frame.writeShort(uint16(len(query.args))) + for i := 0; i < len(query.args); i++ { + val, err := Marshal(info.args[i].TypeInfo, query.args[i]) + if err != nil { + return nil, err + } + frame.writeBytes(val) + } + } + frame.setLength(len(frame) - headerSize) + + frame = c.request(frame) + + if frame[3] == opError { + frame.skipHeader() + code := frame.readInt() + desc := frame.readString() + return nil, Error{code, desc} + } + return frame, nil +} diff --git a/convert.go b/convert.go deleted file mode 100644 index 5db9b9808..000000000 --- a/convert.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright (c) 2012 The gocql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package gocql - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "github.com/tux21b/gocql/uuid" - "math" - "reflect" - "strconv" - "time" -) - -const ( - typeCustom uint16 = 0x0000 - typeAscii uint16 = 0x0001 - typeBigInt uint16 = 0x0002 - typeBlob uint16 = 0x0003 - typeBool uint16 = 0x0004 - typeCounter uint16 = 0x0005 - typeDecimal uint16 = 0x0006 - typeDouble uint16 = 0x0007 - typeFloat uint16 = 0x0008 - typeInt uint16 = 0x0009 - typeText uint16 = 0x000A - typeTimestamp uint16 = 0x000B - typeUUID uint16 = 0x000C - typeVarchar uint16 = 0x000D - typeVarint uint16 = 0x000E - typeTimeUUID uint16 = 0x000F - typeList uint16 = 0x0020 - typeMap uint16 = 0x0021 - typeSet uint16 = 0x0022 -) - -func decode(b []byte, t uint16) driver.Value { - switch t { - case typeBool: - if len(b) >= 1 && b[0] != 0 { - return true - } - return false - case typeBlob: - return b - case typeVarchar, typeText, typeAscii: - return b - case typeInt: - return int64(int32(binary.BigEndian.Uint32(b))) - case typeBigInt: - return int64(binary.BigEndian.Uint64(b)) - case typeFloat: - return float64(math.Float32frombits(binary.BigEndian.Uint32(b))) - case typeDouble: - return math.Float64frombits(binary.BigEndian.Uint64(b)) - case typeTimestamp: - t := int64(binary.BigEndian.Uint64(b)) - sec := t / 1000 - nsec := (t - sec*1000) * 1000000 - return time.Unix(sec, nsec) - case typeUUID, typeTimeUUID: - return uuid.FromBytes(b) - default: - panic("unsupported type") - } - return b -} - -type columnEncoder struct { - columnTypes []uint16 -} - -func (e *columnEncoder) ColumnConverter(idx int) ValueConverter { - switch e.columnTypes[idx] { - case typeInt: - return ValueConverter(encInt) - case typeBigInt: - return ValueConverter(encBigInt) - case typeFloat: - return ValueConverter(encFloat) - case typeDouble: - return ValueConverter(encDouble) - case typeBool: - return ValueConverter(encBool) - case typeVarchar, typeText, typeAscii: - return ValueConverter(encVarchar) - case typeBlob: - return ValueConverter(encBlob) - case typeTimestamp: - return ValueConverter(encTimestamp) - case typeUUID, typeTimeUUID: - return ValueConverter(encUUID) - } - panic("not implemented") -} - -type ValueConverter func(v interface{}) (driver.Value, error) - -func (vc ValueConverter) ConvertValue(v interface{}) (driver.Value, error) { - return vc(v) -} - -func encBool(v interface{}) (driver.Value, error) { - b, err := driver.Bool.ConvertValue(v) - if err != nil { - return nil, err - } - if b.(bool) { - return []byte{1}, nil - } - return []byte{0}, nil -} - -func encInt(v interface{}) (driver.Value, error) { - x, err := driver.Int32.ConvertValue(v) - if err != nil { - return nil, err - } - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, uint32(x.(int64))) - return b, nil -} - -func encBigInt(v interface{}) (driver.Value, error) { - x := reflect.Indirect(reflect.ValueOf(v)).Interface() - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(x.(int64))) - return b, nil -} - -func encVarchar(v interface{}) (driver.Value, error) { - x, err := driver.String.ConvertValue(v) - if err != nil { - return nil, err - } - return []byte(x.(string)), nil -} - -func encFloat(v interface{}) (driver.Value, error) { - x, err := driver.DefaultParameterConverter.ConvertValue(v) - if err != nil { - return nil, err - } - var f float64 - switch x := x.(type) { - case float64: - f = x - case int64: - f = float64(x) - case []byte: - if f, err = strconv.ParseFloat(string(x), 64); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("can not convert %T to float64", x) - } - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, math.Float32bits(float32(f))) - return b, nil -} - -func encDouble(v interface{}) (driver.Value, error) { - x, err := driver.DefaultParameterConverter.ConvertValue(v) - if err != nil { - return nil, err - } - var f float64 - switch x := x.(type) { - case float64: - f = x - case int64: - f = float64(x) - case []byte: - if f, err = strconv.ParseFloat(string(x), 64); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("can not convert %T to float64", x) - } - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, math.Float64bits(f)) - return b, nil -} - -func encTimestamp(v interface{}) (driver.Value, error) { - x, err := driver.DefaultParameterConverter.ConvertValue(v) - if err != nil { - return nil, err - } - var millis int64 - switch x := x.(type) { - case time.Time: - x = x.In(time.UTC) - millis = x.UnixNano() / 1000000 - default: - return nil, fmt.Errorf("can not convert %T to a timestamp", x) - } - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(millis)) - return b, nil -} - -func encBlob(v interface{}) (driver.Value, error) { - x, err := driver.DefaultParameterConverter.ConvertValue(v) - if err != nil { - return nil, err - } - var b []byte - switch x := x.(type) { - case string: - b = []byte(x) - case []byte: - b = x - default: - return nil, fmt.Errorf("can not convert %T to a []byte", x) - } - return b, nil -} - -func encUUID(v interface{}) (driver.Value, error) { - var u uuid.UUID - switch v := v.(type) { - case string: - var err error - u, err = uuid.ParseUUID(v) - if err != nil { - return nil, err - } - case []byte: - u = uuid.FromBytes(v) - case uuid.UUID: - u = v - default: - return nil, fmt.Errorf("can not convert %T to a UUID", v) - } - return u.Bytes(), nil -} diff --git a/doc.go b/doc.go new file mode 100644 index 000000000..bf8489586 --- /dev/null +++ b/doc.go @@ -0,0 +1,9 @@ +// Copyright (c) 2012 The gocql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package gocql implements a fast and robust Cassandra driver for the +// Go programming language. +package gocql + +// TODO(tux21b): write more docs. diff --git a/gocql.go b/gocql.go index f09ee696f..890269702 100644 --- a/gocql.go +++ b/gocql.go @@ -2,616 +2,213 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The gocql package provides a database/sql driver for CQL, the Cassandra -// query language. -// -// This package requires a recent version of Cassandra (≥ 1.2) that supports -// CQL 3.0 and the new native protocol. The native protocol is still considered -// beta and must be enabled manually in Cassandra 1.2 by setting -// "start_native_transport" to true in conf/cassandra.yaml. -// -// Example Usage: -// -// db, err := sql.Open("gocql", "localhost:9042 keyspace=system") -// // ... -// rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces") -// // ... -// for rows.Next() { -// var keyspace string -// err = rows.Scan(&keyspace) -// // ... -// fmt.Println(keyspace) -// } -// if err := rows.Err(); err != nil { -// // ... -// } -// package gocql import ( - "bytes" - "code.google.com/p/snappy-go/snappy" - "database/sql" - "database/sql/driver" - "encoding/binary" + "errors" "fmt" - "io" - "net" "strings" - "time" ) -const ( - protoRequest byte = 0x01 - protoResponse byte = 0x81 - - opError byte = 0x00 - opStartup byte = 0x01 - opReady byte = 0x02 - opAuthenticate byte = 0x03 - opCredentials byte = 0x04 - opOptions byte = 0x05 - opSupported byte = 0x06 - opQuery byte = 0x07 - opResult byte = 0x08 - opPrepare byte = 0x09 - opExecute byte = 0x0A - opLAST byte = 0x0A // not a real opcode -- used to check for valid opcodes - - flagCompressed byte = 0x01 - - keyVersion string = "CQL_VERSION" - keyCompression string = "COMPRESSION" - keyspaceQuery string = "USE " -) - -var consistencyLevels = map[string]byte{"any": 0x00, "one": 0x01, "two": 0x02, - "three": 0x03, "quorum": 0x04, "all": 0x05, "local_quorum": 0x06, "each_quorum": 0x07} - -type drv struct{} - -func (d drv) Open(name string) (driver.Conn, error) { - return Open(name) +type Config struct { + Nodes []string + CQLVersion string + Keyspace string + Consistency Consistency + DefaultPort int } -type connection struct { - c net.Conn - address string - alive bool - pool *pool -} - -type pool struct { - connections []*connection - i int - keyspace string - version string - compression string - consistency byte - dead bool - stop chan struct{} -} - -func Open(name string) (*pool, error) { - parts := strings.Split(name, " ") - var addresses []string - if len(parts) >= 1 { - addresses = strings.Split(parts[0], ",") - } - - version := "3.0.0" - var ( - keyspace string - compression string - consistency byte = 0x01 - ok bool - ) - for i := 1; i < len(parts); i++ { - switch { - case parts[i] == "": - continue - case strings.HasPrefix(parts[i], "keyspace="): - keyspace = strings.TrimSpace(parts[i][9:]) - case strings.HasPrefix(parts[i], "compression="): - compression = strings.TrimSpace(parts[i][12:]) - if compression != "snappy" { - return nil, fmt.Errorf("unknown compression algorithm %q", - compression) - } - case strings.HasPrefix(parts[i], "version="): - version = strings.TrimSpace(parts[i][8:]) - case strings.HasPrefix(parts[i], "consistency="): - cs := strings.TrimSpace(parts[i][12:]) - if consistency, ok = consistencyLevels[cs]; !ok { - return nil, fmt.Errorf("unknown consistency level %q", cs) - } - default: - return nil, fmt.Errorf("unsupported option %q", parts[i]) - } - } - - pool := &pool{ - keyspace: keyspace, - version: version, - compression: compression, - consistency: consistency, - stop: make(chan struct{}), - } - - for _, address := range addresses { - pool.connections = append(pool.connections, &connection{address: address, pool: pool}) - } - - pool.join() - - return pool, nil -} - -func (cn *connection) open() { - cn.alive = false - - var err error - cn.c, err = net.Dial("tcp", cn.address) - if err != nil { - return - } - - var ( - version = cn.pool.version - compression = cn.pool.compression - keyspace = cn.pool.keyspace - ) - - b := &bytes.Buffer{} - - if compression != "" { - binary.Write(b, binary.BigEndian, uint16(2)) - } else { - binary.Write(b, binary.BigEndian, uint16(1)) - } - - binary.Write(b, binary.BigEndian, uint16(len(keyVersion))) - b.WriteString(keyVersion) - binary.Write(b, binary.BigEndian, uint16(len(version))) - b.WriteString(version) - - if compression != "" { - binary.Write(b, binary.BigEndian, uint16(len(keyCompression))) - b.WriteString(keyCompression) - binary.Write(b, binary.BigEndian, uint16(len(compression))) - b.WriteString(compression) +func (c *Config) normalize() { + if c.CQLVersion == "" { + c.CQLVersion = "3.0.0" } - - if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil { - return - } - - opcode, _, err := cn.recv() - if err != nil { - return - } - if opcode != opReady { - return + if c.DefaultPort == 0 { + c.DefaultPort = 9042 } - - if keyspace != "" { - cn.UseKeyspace(keyspace) - } - - cn.alive = true -} - -// close a connection actively, typically used when there's an error and we want to ensure -// we don't repeatedly try to use the broken connection -func (cn *connection) close() { - cn.c.Close() - cn.c = nil // ensure we generate ErrBadConn when cn gets reused - cn.alive = false - - // Check if the entire pool is dead - for _, cn := range cn.pool.connections { - if cn.alive { - return + for i := 0; i < len(c.Nodes); i++ { + c.Nodes[i] = strings.TrimSpace(c.Nodes[i]) + if strings.IndexByte(c.Nodes[i], ':') < 0 { + c.Nodes[i] = fmt.Sprintf("%s:%d", c.Nodes[i], c.DefaultPort) } } - cn.pool.dead = false -} - -// explicitly send a request as uncompressed -// This is only really needed for the "startup" handshake -func (cn *connection) sendUncompressed(opcode byte, body []byte) error { - return cn._send(opcode, body, false) } -func (cn *connection) send(opcode byte, body []byte) error { - return cn._send(opcode, body, cn.pool.compression == "snappy" && len(body) > 0) +type Session struct { + cfg *Config + pool []*connection } -func (cn *connection) _send(opcode byte, body []byte, compression bool) error { - if cn.c == nil { - return driver.ErrBadConn - } - var flags byte = 0x00 - if compression { - var err error - body, err = snappy.Encode(nil, body) - if err != nil { - return err +func NewSession(cfg Config) *Session { + cfg.normalize() + pool := make([]*connection, 0, len(cfg.Nodes)) + for _, address := range cfg.Nodes { + con, err := connect(address, &cfg) + if err == nil { + pool = append(pool, con) } - flags = flagCompressed - } - frame := make([]byte, len(body)+8) - frame[0] = protoRequest - frame[1] = flags - frame[2] = 0 - frame[3] = opcode - binary.BigEndian.PutUint32(frame[4:8], uint32(len(body))) - copy(frame[8:], body) - if _, err := cn.c.Write(frame); err != nil { - return err } - return nil + return &Session{cfg: &cfg, pool: pool} } -func (cn *connection) recv() (byte, []byte, error) { - if cn.c == nil { - return 0, nil, driver.ErrBadConn - } - header := make([]byte, 8) - if _, err := io.ReadFull(cn.c, header); err != nil { - cn.close() // better assume that the connection is broken (may have read some bytes) - return 0, nil, err - } - // verify that the frame starts with version==1 and req/resp flag==response - // this may be overly conservative in that future versions may be backwards compatible - // in that case simply amend the check... - if header[0] != protoResponse { - cn.close() - return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header) - } - // verify that the flags field has only a single flag set, again, this may - // be overly conservative if additional flags are backwards-compatible - if header[1] > 1 { - cn.close() - return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header) - } - opcode := header[3] - if opcode > opLAST { - cn.close() - return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header) - } - length := binary.BigEndian.Uint32(header[4:8]) - var body []byte - if length > 0 { - if length > 256*1024*1024 { // spec says 256MB is max - cn.close() - return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header) - } - body = make([]byte, length) - if _, err := io.ReadFull(cn.c, body); err != nil { - cn.close() // better assume that the connection is broken - return 0, nil, err - } - } - if header[1]&flagCompressed != 0 && cn.pool.compression == "snappy" { - var err error - body, err = snappy.Decode(nil, body) - if err != nil { - cn.close() - return 0, nil, err - } - } - if opcode == opError { - code := binary.BigEndian.Uint32(body[0:4]) - msglen := binary.BigEndian.Uint16(body[4:6]) - msg := string(body[6 : 6+msglen]) - return opcode, body, Error{Code: int(code), Msg: msg} - } - return opcode, body, nil -} - -func (p *pool) conn() (*connection, error) { - if p.dead { - return nil, driver.ErrBadConn - } - - totalConnections := len(p.connections) - start := p.i + 1 // make sure that we start from the next position in the ring - - for i := 0; i < totalConnections; i++ { - idx := (i + start) % totalConnections - cn := p.connections[idx] - if cn.alive { - p.i = idx // set the new 'i' so the ring will start again in the right place - return cn, nil - } +func (s *Session) Query(stmt string, args ...interface{}) *Query { + return &Query{ + stmt: stmt, + args: args, + cons: s.cfg.Consistency, + ctx: s, } - - // we've exhausted the pool, gonna have a bad time - p.dead = true - return nil, driver.ErrBadConn } -func (p *pool) join() { - p.reconnect() - - // Every 1 second, we want to try reconnecting to disconnected nodes - go func() { - for { - select { - case <-p.stop: - return - default: - p.reconnect() - time.Sleep(time.Second) - } - } - }() +func (s *Session) executeQuery(query *Query) (buffer, error) { + // TODO(tux21b): do something clever here + return s.pool[0].executeQuery(query) } -func (p *pool) reconnect() { - for _, cn := range p.connections { - if !cn.alive { - cn.open() - } - } +func (s *Session) Close() { + return } -func (p *pool) Begin() (driver.Tx, error) { - if p.dead { - return nil, driver.ErrBadConn - } - return p, nil -} +type Consistency uint16 -func (p *pool) Commit() error { - if p.dead { - return driver.ErrBadConn - } - return nil -} +const ( + ConAny Consistency = 0x0000 + ConOne Consistency = 0x0001 + ConTwo Consistency = 0x0002 + ConThree Consistency = 0x0003 + ConQuorum Consistency = 0x0004 + ConAll Consistency = 0x0005 + ConLocalQuorum Consistency = 0x0006 + ConEachQuorum Consistency = 0x0007 + ConSerial Consistency = 0x0008 + ConLocalSerial Consistency = 0x0009 +) -func (p *pool) Close() error { - if p.dead { - return driver.ErrBadConn - } - for _, cn := range p.connections { - cn.close() - } - p.stop <- struct{}{} - p.dead = true - return nil -} +var ErrNotFound = errors.New("not found") -func (p *pool) Rollback() error { - if p.dead { - return driver.ErrBadConn +type Query struct { + stmt string + args []interface{} + cons Consistency + ctx interface { + executeQuery(query *Query) (buffer, error) } - return nil } -func (p *pool) Prepare(query string) (driver.Stmt, error) { - // Explicitly check if the query is a "USE " - // Since it needs to be special cased and run on each server - if strings.HasPrefix(query, keyspaceQuery) { - keyspace := query[len(keyspaceQuery):] - p.UseKeyspace(keyspace) - return &statement{}, nil - } - - for { - cn, err := p.conn() - if err != nil { - return nil, err - } - st, err := cn.Prepare(query) - if err != nil { - // the cn has gotten marked as dead already - if p.dead { - // The entire pool is dead, so we bubble up the ErrBadConn - return nil, driver.ErrBadConn - } else { - continue // Retry request on another cn - } - } - return st, nil - } -} +var ErrQueryUnbound = errors.New("can not execute unbound query") -func (p *pool) UseKeyspace(keyspace string) { - p.keyspace = keyspace - for _, cn := range p.connections { - cn.UseKeyspace(keyspace) - } +func NewQuery(stmt string) *Query { + return &Query{stmt: stmt, cons: ConQuorum} } -func (cn *connection) UseKeyspace(keyspace string) error { - st, err := cn.Prepare(keyspaceQuery + keyspace) +func (q *Query) Exec() error { + frame, err := q.request() if err != nil { return err } - if _, err = st.Exec([]driver.Value{}); err != nil { - return err - } - return nil -} - -func (cn *connection) Prepare(query string) (driver.Stmt, error) { - body := make([]byte, len(query)+4) - binary.BigEndian.PutUint32(body[0:4], uint32(len(query))) - copy(body[4:], []byte(query)) - if err := cn.send(opPrepare, body); err != nil { - return nil, err - } - opcode, body, err := cn.recv() - if err != nil { - return nil, err - } - if opcode != opResult || binary.BigEndian.Uint32(body) != 4 { - return nil, fmt.Errorf("expected prepared result") + if frame[3] == opResult { + frame.skipHeader() + kind := frame.readInt() + if kind == 3 { + keyspace := frame.readString() + fmt.Println("set keyspace:", keyspace) + } else { + } } - n := int(binary.BigEndian.Uint16(body[4:])) - prepared := body[6 : 6+n] - columns, meta, _ := parseMeta(body[6+n:]) - return &statement{cn: cn, query: query, - prepared: prepared, columns: columns, meta: meta}, nil -} - -type statement struct { - cn *connection - query string - prepared []byte - columns []string - meta []uint16 -} - -func (s *statement) Close() error { return nil } -func (st *statement) ColumnConverter(idx int) driver.ValueConverter { - return (&columnEncoder{st.meta}).ColumnConverter(idx) +func (q *Query) request() (buffer, error) { + return q.ctx.executeQuery(q) } -func (st *statement) NumInput() int { - return len(st.columns) +func (q *Query) Consistency(cons Consistency) *Query { + q.cons = cons + return q } -func parseMeta(body []byte) ([]string, []uint16, int) { - flags := binary.BigEndian.Uint32(body) - globalTableSpec := flags&1 == 1 - columnCount := int(binary.BigEndian.Uint32(body[4:])) - i := 8 - if globalTableSpec { - l := int(binary.BigEndian.Uint16(body[i:])) - keyspace := string(body[i+2 : i+2+l]) - i += 2 + l - l = int(binary.BigEndian.Uint16(body[i:])) - tablename := string(body[i+2 : i+2+l]) - i += 2 + l - _, _ = keyspace, tablename - } - columns := make([]string, columnCount) - meta := make([]uint16, columnCount) - for c := 0; c < columnCount; c++ { - l := int(binary.BigEndian.Uint16(body[i:])) - columns[c] = string(body[i+2 : i+2+l]) - i += 2 + l - meta[c] = binary.BigEndian.Uint16(body[i:]) - i += 2 +func (q *Query) Scan(values ...interface{}) error { + found := false + iter := q.Iter() + if iter.Scan(values...) { + found = true } - return columns, meta, i -} - -func (st *statement) exec(v []driver.Value) error { - sz := 6 + len(st.prepared) - for i := range v { - if b, ok := v[i].([]byte); ok { - sz += len(b) + 4 - } - } - body, p := make([]byte, sz), 4+len(st.prepared) - binary.BigEndian.PutUint16(body, uint16(len(st.prepared))) - copy(body[2:], st.prepared) - binary.BigEndian.PutUint16(body[p-2:], uint16(len(v))) - for i := range v { - b, ok := v[i].([]byte) - if !ok { - return fmt.Errorf("unsupported type %T at column %d", v[i], i) - } - binary.BigEndian.PutUint32(body[p:], uint32(len(b))) - copy(body[p+4:], b) - p += 4 + len(b) - } - binary.BigEndian.PutUint16(body[p:], uint16(st.cn.pool.consistency)) - if err := st.cn.send(opExecute, body); err != nil { + if err := iter.Close(); err != nil { return err + } else if !found { + return ErrNotFound } return nil } -func (st *statement) Exec(v []driver.Value) (driver.Result, error) { - if st.cn == nil { - return nil, nil - } - if err := st.exec(v); err != nil { - return nil, err - } - opcode, body, err := st.cn.recv() +func (q *Query) Iter() *Iter { + iter := new(Iter) + frame, err := q.request() if err != nil { - return nil, err + iter.err = err + return iter } - _, _ = opcode, body - return nil, nil -} - -func (st *statement) Query(v []driver.Value) (driver.Rows, error) { - if err := st.exec(v); err != nil { - return nil, err + frame.skipHeader() + kind := frame.readInt() + if kind == resultKindRows { + iter.setFrame(frame) } - opcode, body, err := st.cn.recv() - if err != nil { - return nil, err - } - kind := binary.BigEndian.Uint32(body[0:4]) - if opcode != opResult || kind != 2 { - return nil, fmt.Errorf("expected rows as result") - } - columns, meta, n := parseMeta(body[4:]) - i := n + 4 - rows := &rows{ - columns: columns, - meta: meta, - numRows: int(binary.BigEndian.Uint32(body[i:])), - } - i += 4 - rows.body = body[i:] - return rows, nil + return iter } -type rows struct { - columns []string - meta []uint16 - body []byte - row int +type Iter struct { + err error + pos int numRows int + info []columnInfo + flags int + frame buffer +} + +func (iter *Iter) setFrame(frame buffer) { + info := frame.readMetaData() + iter.flags = 0 + iter.info = info + iter.numRows = frame.readInt() + iter.pos = 0 + iter.err = nil + iter.frame = frame +} + +func (iter *Iter) Scan(values ...interface{}) bool { + if iter.err != nil || iter.pos >= iter.numRows { + return false + } + iter.pos++ + if len(values) != len(iter.info) { + iter.err = errors.New("count mismatch") + return false + } + for i := 0; i < len(values); i++ { + data := iter.frame.readBytes() + if err := Unmarshal(iter.info[i].TypeInfo, data, values[i]); err != nil { + iter.err = err + return false + } + } + return true } -func (r *rows) Close() error { - return nil -} - -func (r *rows) Columns() []string { - return r.columns +func (iter *Iter) Close() error { + return iter.err } -func (r *rows) Next(values []driver.Value) error { - if r.row >= r.numRows { - return io.EOF - } - for column := 0; column < len(r.columns); column++ { - n := int32(binary.BigEndian.Uint32(r.body)) - r.body = r.body[4:] - if n >= 0 { - values[column] = decode(r.body[:n], r.meta[column]) - r.body = r.body[n:] - } else { - values[column] = nil - } - } - r.row++ - return nil +type columnInfo struct { + Keyspace string + Table string + Name string + TypeInfo *TypeInfo } type Error struct { - Code int - Msg string + Code int + Message string } func (e Error) Error() string { - return e.Msg -} - -func init() { - sql.Register("gocql", &drv{}) + return e.Message } diff --git a/gocql_test.go b/gocql_test.go index 64f08d693..4d271b1bb 100644 --- a/gocql_test.go +++ b/gocql_test.go @@ -1,42 +1,41 @@ -// Copyright (c) 2012 The gocql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - package gocql import ( "bytes" - "database/sql" - "github.com/tux21b/gocql/uuid" + "fmt" "testing" "time" ) -func TestSimple(t *testing.T) { - db, err := sql.Open("gocql", "localhost:9042 keyspace=system") - if err != nil { - t.Fatal(err) - } +func TestConnect(t *testing.T) { + db := NewSession(Config{ + Nodes: []string{ + "127.0.0.1", + }, + Keyspace: "system", + Consistency: ConQuorum, + }) + defer db.Close() - rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces") - if err != nil { - t.Fatal(err) + for i := 0; i < 5; i++ { + db.Query("SELECT keyspace_name FROM schema_keyspaces WHERE keyspace_name = ?", + "system_auth").Exec() } - for rows.Next() { - var keyspace string - if err := rows.Scan(&keyspace); err != nil { - t.Fatal(err) - } + var keyspace string + var durable bool + iter := db.Query("SELECT keyspace_name, durable_writes FROM schema_keyspaces").Iter() + for iter.Scan(&keyspace, &durable) { + fmt.Println("Keyspace:", keyspace, durable) } - if err != nil { - t.Fatal(err) + if err := iter.Close(); err != nil { + fmt.Println(err) } } type Page struct { Title string - RevID uuid.UUID + RevID int Body string Hits int Protected bool @@ -45,67 +44,74 @@ type Page struct { } var pages = []*Page{ - &Page{"Frontpage", uuid.TimeUUID(), "Hello world!", 0, false, - time.Date(2012, 8, 20, 10, 0, 0, 0, time.UTC), nil}, - &Page{"Frontpage", uuid.TimeUUID(), "Hello modified world!", 0, false, + &Page{"Frontpage", 1, "Hello world!", 0, false, + time.Date(2012, 8, 20, 10, 0, 0, 0, time.UTC), []byte{}}, + &Page{"Frontpage", 2, "Hello modified world!", 0, false, time.Date(2012, 8, 22, 10, 0, 0, 0, time.UTC), []byte("img data\x00")}, - &Page{"LoremIpsum", uuid.TimeUUID(), "Lorem ipsum dolor sit amet", 12, - true, time.Date(2012, 8, 22, 10, 0, 8, 0, time.UTC), nil}, + &Page{"LoremIpsum", 3, "Lorem ipsum dolor sit amet", 12, + true, time.Date(2012, 8, 22, 10, 0, 8, 0, time.UTC), []byte{}}, } func TestWiki(t *testing.T) { - db, err := sql.Open("gocql", "localhost:9042 compression=snappy") - if err != nil { - t.Fatal(err) + db := NewSession(Config{ + Nodes: []string{"localhost"}, + Consistency: ConQuorum, + }) + + if err := db.Query("DROP KEYSPACE gocql_wiki").Exec(); err != nil { + t.Log("DROP KEYSPACE:", err) } - db.Exec("DROP KEYSPACE gocql_wiki") - if _, err := db.Exec(`CREATE KEYSPACE gocql_wiki - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }`); err != nil { - t.Fatal(err) + + if err := db.Query(`CREATE KEYSPACE gocql_wiki + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }`).Exec(); err != nil { + t.Fatal("CREATE KEYSPACE:", err) } - if _, err := db.Exec("USE gocql_wiki"); err != nil { - t.Fatal(err) + + if err := db.Query("USE gocql_wiki").Exec(); err != nil { + t.Fatal("USE:", err) } - if _, err := db.Exec(`CREATE TABLE page ( - title varchar, - revid timeuuid, - body varchar, - hits int, - protected boolean, - modified timestamp, - attachment blob, - PRIMARY KEY (title, revid) - )`); err != nil { - t.Fatal(err) + if err := db.Query(`CREATE TABLE page ( + title varchar, + revid int, + body varchar, + hits int, + protected boolean, + modified timestamp, + attachment blob, + PRIMARY KEY (title, revid) + )`).Exec(); err != nil { + t.Fatal("CREATE TABLE:", err) } + for _, p := range pages { - if _, err := db.Exec(`INSERT INTO page (title, revid, body, hits, - protected, modified, attachment) VALUES (?, ?, ?, ?, ?, ?, ?);`, + if err := db.Query(`INSERT INTO page (title, revid, body, hits, + protected, modified, attachment) VALUES (?, ?, ?, ?, ?, ?, ?)`, p.Title, p.RevID, p.Body, p.Hits, p.Protected, p.Modified, - p.Attachment); err != nil { - t.Fatal(err) + p.Attachment).Exec(); err != nil { + t.Fatal("INSERT:", err) } } - row := db.QueryRow(`SELECT count(*) FROM page`) var count int - if err := row.Scan(&count); err != nil { - t.Error(err) + if err := db.Query("SELECT count(*) FROM page").Scan(&count); err != nil { + t.Fatal("COUNT:", err) } if count != len(pages) { - t.Fatalf("expected %d rows, got %d", len(pages), count) + t.Fatalf("COUNT: expected %d got %d", len(pages), count) } for _, page := range pages { - row := db.QueryRow(`SELECT title, revid, body, hits, protected, - modified, attachment - FROM page WHERE title = ? AND revid = ?`, page.Title, page.RevID) + qry := db.Query(`SELECT title, revid, body, hits, protected, + modified, attachment + FROM page WHERE title = ? AND revid = ?`, page.Title, page.RevID) var p Page - err := row.Scan(&p.Title, &p.RevID, &p.Body, &p.Hits, &p.Protected, - &p.Modified, &p.Attachment) - if err != nil { - t.Fatal(err) + if err := qry.Scan(&p.Title, &p.RevID, &p.Body, &p.Hits, &p.Protected, + &p.Modified, &p.Attachment); err != nil { + t.Fatal("SELECT PAGE:", err) } p.Modified = p.Modified.In(time.UTC) if page.Title != p.Title || page.RevID != p.RevID || @@ -115,111 +121,4 @@ func TestWiki(t *testing.T) { t.Errorf("expected %#v got %#v", *page, p) } } - - row = db.QueryRow(`SELECT title, revid, body, hits, protected, - modified, attachment - FROM page WHERE title = ? ORDER BY revid DESC`, "Frontpage") - var p Page - if err := row.Scan(&p.Title, &p.RevID, &p.Body, &p.Hits, &p.Protected, - &p.Modified, &p.Attachment); err != nil { - t.Error(err) - } - p.Modified = p.Modified.In(time.UTC) - page := pages[1] - if page.Title != p.Title || page.RevID != p.RevID || - page.Body != p.Body || page.Modified != p.Modified || - page.Hits != p.Hits || page.Protected != p.Protected || - !bytes.Equal(page.Attachment, p.Attachment) { - t.Errorf("expected %#v got %#v", *page, p) - } - -} - -func TestTypes(t *testing.T) { - db, err := sql.Open("gocql", "localhost:9042 compression=snappy") - if err != nil { - t.Fatal(err) - } - db.Exec("DROP KEYSPACE gocql_types") - if _, err := db.Exec(`CREATE KEYSPACE gocql_types - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }`); err != nil { - t.Fatal(err) - } - if _, err := db.Exec("USE gocql_types"); err != nil { - t.Fatal(err) - } - - if _, err := db.Exec(`CREATE TABLE stuff ( - id bigint, - foo text, - PRIMARY KEY (id) - )`); err != nil { - t.Fatal(err) - } - - id := int64(-1 << 63) - - if _, err := db.Exec(`INSERT INTO stuff (id, foo) VALUES (?, ?);`, &id, "test"); err != nil { - t.Fatal(err) - } - - var rid int64 - - row := db.QueryRow(`SELECT id FROM stuff WHERE id = ?`, id) - - if err := row.Scan(&rid); err != nil { - t.Error(err) - } - - if id != rid { - t.Errorf("expected %v got %v", id, rid) - } -} - - -func TestNullColumnValues(t *testing.T) { - db, err := sql.Open("gocql", "localhost:9042 compression=snappy") - if err != nil { - t.Fatal(err) - } - db.Exec("DROP KEYSPACE gocql_nullvalues") - if _, err := db.Exec(`CREATE KEYSPACE gocql_nullvalues - WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`); err != nil { - t.Fatal(err) - } - if _, err := db.Exec("USE gocql_nullvalues"); err != nil { - t.Fatal(err) - } - if _, err := db.Exec(`CREATE TABLE stuff ( - id bigint, - subid bigint, - foo text, - bar text, - PRIMARY KEY (id, subid) - )`); err != nil { - t.Fatal(err) - } - id := int64(-1 << 63) - - if _, err := db.Exec(`INSERT INTO stuff (id, subid, foo) VALUES (?, ?, ?);`, id, int64(4), "test"); err != nil { - t.Fatal(err) - } - - if _, err := db.Exec(`INSERT INTO stuff (id, subid, bar) VALUES (?, ?, ?);`, id, int64(6), "test2"); err != nil { - t.Fatal(err) - } - - var rid int64 - var sid int64 - var data1 []byte - var data2 []byte - if rows, err := db.Query(`SELECT id, subid, foo, bar FROM stuff`); err == nil { - for rows.Next() { - if err := rows.Scan(&rid, &sid, &data1, &data2); err != nil { - t.Error(err) - } - } - } else { - t.Fatal(err) - } } diff --git a/marshal.go b/marshal.go new file mode 100644 index 000000000..6680e2ac7 --- /dev/null +++ b/marshal.go @@ -0,0 +1,204 @@ +// Copyright (c) 2012 The gocql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gocql + +import ( + "fmt" + "time" +) + +// Marshaler is the interface implemented by objects that can marshal +// themselves into values understood by Cassandra. +type Marshaler interface { + MarshalCQL(info *TypeInfo, value interface{}) ([]byte, error) +} + +// Unmarshaler is the interface implemented by objects that can unmarshal +// a Cassandra specific description of themselves. +type Unmarshaler interface { + UnmarshalCQL(info *TypeInfo, data []byte, value interface{}) error +} + +// Marshal returns the CQL encoding of the value for the Cassandra +// internal type described by the info parameter. +func Marshal(info *TypeInfo, value interface{}) ([]byte, error) { + if v, ok := value.(Marshaler); ok { + return v.MarshalCQL(info, value) + } + switch info.Type { + case TypeVarchar, TypeAscii, TypeBlob: + switch v := value.(type) { + case string: + return []byte(v), nil + case []byte: + return v, nil + } + case TypeBoolean: + if v, ok := value.(bool); ok { + if v { + return []byte{1}, nil + } else { + return []byte{0}, nil + } + } + case TypeInt: + switch v := value.(type) { + case int: + x := int32(v) + return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}, nil + } + case TypeTimestamp: + if v, ok := value.(time.Time); ok { + x := v.In(time.UTC).UnixNano() / int64(time.Millisecond) + return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), + byte(x >> 32), byte(x >> 24), byte(x >> 16), + byte(x >> 8), byte(x)}, nil + } + } + // TODO(tux21b): add reflection and a lot of other types + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +// Unmarshal parses the CQL encoded data based on the info parameter that +// describes the Cassandra internal data type and stores the result in the +// value pointed by value. +func Unmarshal(info *TypeInfo, data []byte, value interface{}) error { + if v, ok := value.(Unmarshaler); ok { + return v.UnmarshalCQL(info, data, value) + } + switch info.Type { + case TypeVarchar, TypeAscii, TypeBlob: + switch v := value.(type) { + case *string: + *v = string(data) + return nil + case *[]byte: + val := make([]byte, len(data)) + copy(val, data) + *v = val + return nil + } + case TypeBoolean: + if v, ok := value.(*bool); ok && len(data) == 1 { + *v = data[0] != 0 + return nil + } + case TypeBigInt: + if v, ok := value.(*int); ok && len(data) == 8 { + *v = int(data[0])<<56 | int(data[1])<<48 | int(data[2])<<40 | + int(data[3])<<32 | int(data[4])<<24 | int(data[5])<<16 | + int(data[6])<<8 | int(data[7]) + return nil + } + case TypeInt: + if v, ok := value.(*int); ok && len(data) == 4 { + *v = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | + int(data[3]) + return nil + } + case TypeTimestamp: + if v, ok := value.(*time.Time); ok && len(data) == 8 { + x := int64(data[0])<<56 | int64(data[1])<<48 | + int64(data[2])<<40 | int64(data[3])<<32 | + int64(data[4])<<24 | int64(data[5])<<16 | + int64(data[6])<<8 | int64(data[7]) + sec := x / 1000 + nsec := (x - sec*1000) * 1000000 + *v = time.Unix(sec, nsec) + return nil + } + } + // TODO(tux21b): add reflection and a lot of other basic types + return fmt.Errorf("can not unmarshal %s into %T", info, value) +} + +// TypeInfo describes a Cassandra specific data type. +type TypeInfo struct { + Type Type + Key *TypeInfo // only used for TypeMap + Value *TypeInfo // only used for TypeMap, TypeList and TypeSet + Custom string // only used for TypeCostum +} + +// String returns a human readable name for the Cassandra datatype +// described by t. +func (t TypeInfo) String() string { + switch t.Type { + case TypeMap: + return fmt.Sprintf("%s(%s, %s)", t.Type, t.Key, t.Value) + case TypeList, TypeSet: + return fmt.Sprintf("%s(%s)", t.Type, t.Value) + case TypeCustom: + return fmt.Sprintf("%s(%s)", t.Type, t.Custom) + } + return t.Type.String() +} + +// Type is the identifier of a Cassandra internal datatype. +type Type int + +const ( + TypeCustom Type = 0x0000 + TypeAscii Type = 0x0001 + TypeBigInt Type = 0x0002 + TypeBlob Type = 0x0003 + TypeBoolean Type = 0x0004 + TypeCounter Type = 0x0005 + TypeDecimal Type = 0x0006 + TypeDouble Type = 0x0007 + TypeFloat Type = 0x0008 + TypeInt Type = 0x0009 + TypeTimestamp Type = 0x000B + TypeUUID Type = 0x000C + TypeVarchar Type = 0x000D + TypeVarint Type = 0x000E + TypeTimeUUID Type = 0x000F + TypeInet Type = 0x0010 + TypeList Type = 0x0020 + TypeMap Type = 0x0021 + TypeSet Type = 0x0022 +) + +// String returns the name of the identifier. +func (t Type) String() string { + switch t { + case TypeCustom: + return "custom" + case TypeAscii: + return "ascii" + case TypeBigInt: + return "bigint" + case TypeBlob: + return "blob" + case TypeBoolean: + return "boolean" + case TypeCounter: + return "counter" + case TypeDecimal: + return "decimal" + case TypeFloat: + return "float" + case TypeInt: + return "int" + case TypeTimestamp: + return "timestamp" + case TypeUUID: + return "uuid" + case TypeVarchar: + return "varchar" + case TypeTimeUUID: + return "timeuuid" + case TypeInet: + return "inet" + case TypeList: + return "list" + case TypeMap: + return "map" + case TypeSet: + return "set" + default: + return "unknown" + } +}