A Go package to make patching your Postgres database schema easier.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

285 lines
6.2 KiB

  1. // Copyright 2017 Andrew Chilton
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package patchy
  15. import (
  16. "database/sql"
  17. "fmt"
  18. "io/ioutil"
  19. "log"
  20. "path"
  21. "strconv"
  22. "strings"
  23. _ "github.com/lib/pq"
  24. )
// patchKey is the key of the row in the property table whose value column
// holds the current patch level.
var patchKey = "patch"
// Options configures a call to Patch. A zero value (or a nil *Options) is
// valid; Patch substitutes a default for every empty field.
type Options struct {
	// Dir is the directory that is scanned for "<n>-forward.sql" and
	// "<n>-reverse.sql" patch files. Patch uses "." when it is empty.
	Dir string
	// PropertyTable is the name of the table in which the current patch level
	// is stored. Patch uses "property" when it is empty.
	PropertyTable string
}
// patch holds the SQL text for one numbered patch level.
type patch struct {
	// Forward is the contents of the "<n>-forward.sql" file, run when moving
	// up to this level.
	Forward string
	// Reverse is the contents of the "<n>-reverse.sql" file, run when moving
	// down from this level.
	Reverse string
}
  34. func oneRowOneColBool(rows *sql.Rows) (bool, error) {
  35. var val bool
  36. var rowCount int
  37. for rows.Next() {
  38. rowCount++
  39. err := rows.Scan(&val)
  40. if err != nil {
  41. return val, err
  42. }
  43. }
  44. return val, rows.Err()
  45. }
  46. func insertPatchLevel(db *sql.DB, tableName string) error {
  47. sqlIns := `INSERT INTO ` + tableName + `(key, value) VALUES($1, 1)`
  48. _, err := db.Exec(sqlIns, patchKey)
  49. return err
  50. }
  51. func createPropertyTable(db *sql.DB, tableName string) error {
  52. sqlCreateTable := `
  53. CREATE TABLE ` + tableName + ` (
  54. key TEXT PRIMARY KEY,
  55. value TEXT
  56. );
  57. `
  58. _, err := db.Exec(sqlCreateTable)
  59. if err != nil {
  60. return err
  61. }
  62. err = insertPatchLevel(db, tableName)
  63. return err
  64. }
  65. func getCurrentLevel(db *sql.DB, tableName string) (int, error) {
  66. sqlSel := "SELECT value FROM " + tableName + " WHERE key = $1"
  67. fmt.Printf("sql=%s\n", sqlSel)
  68. fmt.Printf("key=%s\n", patchKey)
  69. var level int
  70. row := db.QueryRow(sqlSel, patchKey)
  71. if err := row.Scan(&level); err != nil {
  72. if err == sql.ErrNoRows {
  73. fmt.Printf("*** NO ROWS\n")
  74. err := insertPatchLevel(db, tableName)
  75. return 0, err
  76. }
  77. return 0, err
  78. }
  79. return level, nil
  80. }
  81. func Patch(db *sql.DB, level int, opts *Options) (int, error) {
  82. if opts == nil {
  83. opts = &Options{}
  84. }
  85. // set some option defaults
  86. if opts.Dir == "" {
  87. opts.Dir = "."
  88. }
  89. if opts.PropertyTable == "" {
  90. opts.PropertyTable = "property"
  91. }
  92. fmt.Printf("Options=%#v\n", opts)
  93. // read the dir list
  94. files, err := ioutil.ReadDir(opts.Dir)
  95. if err != nil {
  96. return 0, err
  97. }
  98. // check the files are what we expect
  99. patchSet := make(map[int]*patch)
  100. for _, file := range files {
  101. fmt.Printf("file=%s\n", file.Name())
  102. patchDirection := ""
  103. if strings.HasSuffix(file.Name(), "-forward.sql") {
  104. fmt.Printf("- forward\n")
  105. patchDirection = "forward"
  106. } else if strings.HasSuffix(file.Name(), "-reverse.sql") {
  107. fmt.Printf("- reverse\n")
  108. patchDirection = "reverse"
  109. } else {
  110. fmt.Printf("- unknown\n")
  111. continue
  112. }
  113. // what is the patch level of this filename
  114. patchLevel := file.Name()
  115. patchLevel = strings.TrimSuffix(patchLevel, "-"+patchDirection+".sql")
  116. // check this is a number
  117. n, err := strconv.Atoi(patchLevel)
  118. if err != nil {
  119. return 0, err
  120. }
  121. // read this file in
  122. if _, ok := patchSet[n]; ok {
  123. fmt.Printf("patch already exists, so adding the other direction (hopefully)\n")
  124. } else {
  125. fmt.Printf("no such patch in patchset yet - adding\n")
  126. patchSet[n] = &patch{}
  127. }
  128. filename := path.Join(opts.Dir, file.Name())
  129. sql, err := ioutil.ReadFile(filename)
  130. if err != nil {
  131. return 0, err
  132. }
  133. p := patchSet[n]
  134. if patchDirection == "forward" {
  135. p.Forward = string(sql)
  136. }
  137. if patchDirection == "reverse" {
  138. p.Reverse = string(sql)
  139. }
  140. fmt.Printf("p=%#v\n", p)
  141. }
  142. fmt.Printf("PatchSet=%v\n", patchSet)
  143. // ToDo: check that we have all patches (both Forward and Reverse) up to the level required
  144. // firstly, figure out if the property table exists
  145. sqlPropertyTableExists := `
  146. SELECT EXISTS (
  147. SELECT
  148. *
  149. FROM
  150. information_schema.tables
  151. WHERE
  152. table_schema = 'public'
  153. AND
  154. table_name = $1
  155. );
  156. `
  157. // check to see if the property table exists
  158. var propertyTableExists bool
  159. row := db.QueryRow(sqlPropertyTableExists, opts.PropertyTable)
  160. if err := row.Scan(&propertyTableExists); err != nil {
  161. return 0, err
  162. }
  163. if propertyTableExists == false {
  164. fmt.Printf("Creating property table\n")
  165. err := createPropertyTable(db, opts.PropertyTable)
  166. if err != nil {
  167. return 0, err
  168. }
  169. }
  170. // current patch level
  171. currentLevel, err := getCurrentLevel(db, opts.PropertyTable)
  172. if err != nil {
  173. return 0, err
  174. }
  175. fmt.Printf("-> Current Level = %d\n", currentLevel)
  176. // figure out which direction we're actually going to go in
  177. direction := ""
  178. step := 0
  179. if currentLevel < level {
  180. direction = "forward"
  181. step = 1
  182. }
  183. if currentLevel > level {
  184. direction = "reverse"
  185. step = -1
  186. }
  187. if currentLevel == level {
  188. fmt.Printf("Nothing to do, currently at the same level %d\n", level)
  189. return level, nil
  190. }
  191. fmt.Printf("-> Direction = %s\n", direction)
  192. for num := currentLevel; num != level; num += step {
  193. fmt.Printf("- doing from %d to %d\n", num, num+step)
  194. // loop through all of the patches we know about
  195. fmt.Printf("-> BEGIN ...\n")
  196. tx, err := db.Begin()
  197. if err != nil {
  198. return 0, err
  199. }
  200. fmt.Printf("-> BEGIN done\n")
  201. // update with this patch
  202. sql := ""
  203. if direction == "forward" {
  204. sql = patchSet[num+step].Forward
  205. }
  206. if direction == "reverse" {
  207. sql = patchSet[num].Reverse
  208. }
  209. fmt.Printf("-----> SQL = %s\n", sql)
  210. _, err = db.Exec(sql)
  211. if err != nil {
  212. fmt.Printf("-> ROLLBACK ...\n")
  213. err2 := tx.Rollback()
  214. if err2 != nil {
  215. log.Fatal(err2)
  216. }
  217. fmt.Printf("-> ROLLBACK done\n")
  218. return num, err
  219. }
  220. // update the property table
  221. _, err = db.Exec(`UPDATE property SET value = $1 WHERE key = $2`, num+step, patchKey)
  222. if err != nil {
  223. fmt.Printf("-> ROLLBACK ...\n")
  224. err2 := tx.Rollback()
  225. if err2 != nil {
  226. log.Fatal(err2)
  227. }
  228. fmt.Printf("-> ROLLBACK done\n")
  229. return num, err
  230. }
  231. // commit
  232. fmt.Printf("-> COMMIT ...\n")
  233. err = tx.Commit()
  234. if err != nil {
  235. err2 := tx.Rollback()
  236. if err2 != nil {
  237. log.Fatal(err2)
  238. }
  239. return 0, err
  240. }
  241. fmt.Printf("-> COMMIT done\n")
  242. }
  243. return level, nil
  244. }