| package sqlite3 |
| |
import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	nurl "net/url"
	"strings"

	"github.com/mattes/migrate"
	"github.com/mattes/migrate/database"
	_ "github.com/mattn/go-sqlite3"
)
| |
| func init() { |
| database.Register("sqlite3", &Sqlite{}) |
| } |
| |
// DefaultMigrationsTable is the name of the table that records the applied
// migration version when no other name is configured.
var DefaultMigrationsTable = "schema_migrations"

// Sentinel errors for this driver. They are plain errors.New values, so
// callers may compare them with == or errors.Is.
var (
	// ErrDatabaseDirty reports that the database is marked dirty.
	// NOTE(review): not returned by any code visible in this file.
	ErrDatabaseDirty = errors.New("database is dirty")
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = errors.New("no config")
	// ErrNoDatabaseName reports a missing database name.
	// NOTE(review): not returned by any code visible in this file.
	ErrNoDatabaseName = errors.New("no database name")
)
| |
// Config carries the options accepted by WithInstance.
type Config struct {
	// MigrationsTable names the table that stores the migration version.
	// WithInstance replaces an empty value with DefaultMigrationsTable
	// (writing it back into this Config).
	MigrationsTable string
	// DatabaseName is set by Open from the URL path.
	// NOTE(review): no method in the visible code reads it.
	DatabaseName string
}
| |
// Sqlite is the migrate database driver for SQLite, operating on a *sql.DB.
// Instances are produced by WithInstance or Open.
type Sqlite struct {
	// db is the handle every query is executed against.
	db *sql.DB
	// isLocked is an in-memory flag toggled by Lock/Unlock. It is not a
	// database lock and is not guarded by any mutex.
	isLocked bool

	// config holds the resolved options, including the migrations table name.
	config *Config
}
| |
| func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { |
| if config == nil { |
| return nil, ErrNilConfig |
| } |
| |
| if err := instance.Ping(); err != nil { |
| return nil, err |
| } |
| if len(config.MigrationsTable) == 0 { |
| config.MigrationsTable = DefaultMigrationsTable |
| } |
| |
| mx := &Sqlite{ |
| db: instance, |
| config: config, |
| } |
| if err := mx.ensureVersionTable(); err != nil { |
| return nil, err |
| } |
| return mx, nil |
| } |
| |
| func (m *Sqlite) ensureVersionTable() error { |
| |
| query := fmt.Sprintf(` |
| CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); |
| CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); |
| `, DefaultMigrationsTable, DefaultMigrationsTable) |
| |
| if _, err := m.db.Exec(query); err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| func (m *Sqlite) Open(url string) (database.Driver, error) { |
| purl, err := nurl.Parse(url) |
| if err != nil { |
| return nil, err |
| } |
| dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1) |
| db, err := sql.Open("sqlite3", dbfile) |
| if err != nil { |
| return nil, err |
| } |
| |
| migrationsTable := purl.Query().Get("x-migrations-table") |
| if len(migrationsTable) == 0 { |
| migrationsTable = DefaultMigrationsTable |
| } |
| mx, err := WithInstance(db, &Config{ |
| DatabaseName: purl.Path, |
| MigrationsTable: migrationsTable, |
| }) |
| if err != nil { |
| return nil, err |
| } |
| return mx, nil |
| } |
| |
| func (m *Sqlite) Close() error { |
| return m.db.Close() |
| } |
| |
| func (m *Sqlite) Drop() error { |
| query := `SELECT name FROM sqlite_master WHERE type = 'table';` |
| tables, err := m.db.Query(query) |
| if err != nil { |
| return &database.Error{OrigErr: err, Query: []byte(query)} |
| } |
| defer tables.Close() |
| tableNames := make([]string, 0) |
| for tables.Next() { |
| var tableName string |
| if err := tables.Scan(&tableName); err != nil { |
| return err |
| } |
| if len(tableName) > 0 { |
| tableNames = append(tableNames, tableName) |
| } |
| } |
| if len(tableNames) > 0 { |
| for _, t := range tableNames { |
| query := "DROP TABLE " + t |
| err = m.executeQuery(query) |
| if err != nil { |
| return &database.Error{OrigErr: err, Query: []byte(query)} |
| } |
| } |
| if err := m.ensureVersionTable(); err != nil { |
| return err |
| } |
| query := "VACUUM" |
| _, err = m.db.Query(query) |
| if err != nil { |
| return &database.Error{OrigErr: err, Query: []byte(query)} |
| } |
| } |
| |
| return nil |
| } |
| |
| func (m *Sqlite) Lock() error { |
| if m.isLocked { |
| return database.ErrLocked |
| } |
| m.isLocked = true |
| return nil |
| } |
| |
| func (m *Sqlite) Unlock() error { |
| if !m.isLocked { |
| return nil |
| } |
| m.isLocked = false |
| return nil |
| } |
| |
| func (m *Sqlite) Run(migration io.Reader) error { |
| migr, err := ioutil.ReadAll(migration) |
| if err != nil { |
| return err |
| } |
| query := string(migr[:]) |
| |
| return m.executeQuery(query) |
| } |
| |
| func (m *Sqlite) executeQuery(query string) error { |
| tx, err := m.db.Begin() |
| if err != nil { |
| return &database.Error{OrigErr: err, Err: "transaction start failed"} |
| } |
| if _, err := tx.Exec(query); err != nil { |
| tx.Rollback() |
| return &database.Error{OrigErr: err, Query: []byte(query)} |
| } |
| if err := tx.Commit(); err != nil { |
| return &database.Error{OrigErr: err, Err: "transaction commit failed"} |
| } |
| return nil |
| } |
| |
| func (m *Sqlite) SetVersion(version int, dirty bool) error { |
| tx, err := m.db.Begin() |
| if err != nil { |
| return &database.Error{OrigErr: err, Err: "transaction start failed"} |
| } |
| |
| query := "DELETE FROM " + m.config.MigrationsTable |
| if _, err := tx.Exec(query); err != nil { |
| return &database.Error{OrigErr: err, Query: []byte(query)} |
| } |
| |
| if version >= 0 { |
| query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, '%t')`, m.config.MigrationsTable, version, dirty) |
| if _, err := tx.Exec(query); err != nil { |
| tx.Rollback() |
| return &database.Error{OrigErr: err, Query: []byte(query)} |
| } |
| } |
| |
| if err := tx.Commit(); err != nil { |
| return &database.Error{OrigErr: err, Err: "transaction commit failed"} |
| } |
| |
| return nil |
| } |
| |
| func (m *Sqlite) Version() (version int, dirty bool, err error) { |
| query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" |
| err = m.db.QueryRow(query).Scan(&version, &dirty) |
| if err != nil { |
| return database.NilVersion, false, nil |
| } |
| return version, dirty, nil |
| } |