-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.ts
105 lines (96 loc) · 2.89 KB
/
model.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import { mkdirSync } from 'fs'
import { join } from 'path'
import { getDirFilenamesSync } from '@beenotung/tslib/fs'
import {
PreTrainedImageModels,
loadImageClassifierModel,
loadImageModel,
} from 'tensorflow-helpers'
/**
 * Ensure `dir` contains one sub-directory per entry of `classNames`.
 *
 * Uses `mkdirSync(..., { recursive: true })`, so already-existing directories
 * (and a missing `dir` itself) are tolerated.
 *
 * @param dir parent directory, e.g. `dataset` or `classified`
 * @param classNames names of the sub-directories to create under `dir`
 */
function createClassNameDirectories(dir: string, classNames: string[]) {
  classNames.forEach(name => {
    const classDir = join(dir, name)
    mkdirSync(classDir, { recursive: true })
  })
}
/**
 * List the classifier's class names from the entries of the `dataset` directory.
 *
 * The `dataset` directory is created first when missing (presumably so the
 * listing below cannot fail on a fresh checkout). If the listing is empty, a
 * hint is printed to stderr and the process exits with status 1.
 *
 * @returns the names returned by `getDirFilenamesSync('dataset')`
 */
export function getClassNames(): string[] {
  const datasetDir = 'dataset'
  mkdirSync(datasetDir, { recursive: true })

  const classNames: string[] = getDirFilenamesSync(datasetDir)
  const isEmpty = classNames.length === 0
  if (isEmpty) {
    // Nothing to classify into: tell the user how to lay out the dataset.
    console.error('Error: no class names found in dataset directory')
    console.error('Example: others, cat, dog, both')
    process.exit(1)
  }

  return classNames
}
/**
 * Prepare the dataset folders and load the image embedding model and the
 * image classifier.
 *
 * 1. Reads class names via `getClassNames()` (which exits the process when
 *    `dataset` is empty) and ensures `dataset/<class>` and
 *    `classified/<class>` directories exist.
 * 2. Builds an embedding cache backed by the `image` table of the database
 *    exported from `./db`. Each embedding is stored as a comma-separated
 *    string in the `embedding` column, keyed by `filename`.
 * 3. Loads the MobileNet V3 (`mobilenet-v3-large-100`) base model from
 *    `./saved_models/base_model` with that cache. It then loads the classifier
 *    from `./saved_models/classifier_model` with `datasetDir: './dataset'` and
 *    one hidden layer sized to the base model's feature count.
 *
 * @returns the embedding cache, the base image model and the classifier model
 */
export async function loadModels() {
  let classNames = getClassNames()
  createClassNameDirectories('dataset', classNames)
  createClassNameDirectories('classified', classNames)

  let imageModelSpec = PreTrainedImageModels.mobilenet['mobilenet-v3-large-100']

  // Imported lazily so the database is only opened when models are loaded.
  let { db } = await import('./db')

  // Yields 1 when a row exists with a non-null embedding, 0 when the row exists
  // with a null embedding, and undefined (via .get) when no row matches.
  let has_embedding = db
    .prepare<string, number>(
      /* sql */ `select (case when embedding is null then 0 else 1 end) as count from image where filename = ?`,
    )
    .pluck()
  let select_embedding = db
    .prepare<string, string | null>(
      /* sql */ `select embedding from image where filename = ?`,
    )
    .pluck()
  // Write-only statements: only `.run()` is used, so no result row type.
  let update_embedding = db.prepare<{ embedding: string; filename: string }>(
    /* sql */ `update image set embedding = :embedding where filename = :filename`,
  )
  let insert_embedding = db.prepare<{ embedding: string; filename: string }>(
    /* sql */ `insert into image (filename, embedding) values (:filename, :embedding)`,
  )
  let select_cached_images = db
    .prepare<void[], string>(
      /* sql */ `
select filename from image where embedding is not null
`,
    )
    .pluck()

  let embeddingCache = {
    /** Filenames that currently have a non-null stored embedding. */
    keys(): string[] {
      return select_cached_images.all()
    },
    /** Whether `filename` has a non-null stored embedding. */
    has(filename: string) {
      // `.get()` returns undefined when no row matches; compare strictly
      // instead of asserting the result is non-null.
      return has_embedding.get(filename) === 1
    },
    /** Decode the stored comma-separated embedding, or null when absent. */
    get(filename: string) {
      let embedding = select_embedding.get(filename)
      if (!embedding) return null
      return embedding.split(',').map(s => +s)
    },
    /** Store `values` for `filename`, updating existing rows before inserting. */
    set(filename: string, values: number[]) {
      let embedding = values.join(',')
      // Any updated row means the embedding is stored. Checking `== 1` would
      // insert yet another duplicate when `filename` already has 2+ rows.
      if (update_embedding.run({ filename, embedding }).changes > 0) {
        return
      }
      insert_embedding.run({ filename, embedding })
    },
  }

  let baseModel = await loadImageModel({
    dir: './saved_models/base_model',
    spec: imageModelSpec,
    cache: embeddingCache,
  })

  let classifierModel = await loadImageClassifierModel({
    modelDir: './saved_models/classifier_model',
    datasetDir: './dataset',
    baseModel,
    classNames,
    hiddenLayers: [imageModelSpec.features],
  })

  return {
    embeddingCache,
    baseModel,
    classifierModel,
  }
}