1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-04 02:46:47 +01:00

chore(server): startup check for pgvecto.rs (#5815)

* startup check for pgvecto.rs

* prefilter after assertion

* formatting

* add assert to migration

* more specific import

* use runner
This commit is contained in:
Mert 2023-12-18 11:38:25 -05:00 committed by GitHub
parent fade8b627f
commit de1514a441
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 15 deletions

View file

@ -1,5 +1,5 @@
import { envName, isDev, serverVersion } from '@app/domain'; import { envName, isDev, serverVersion } from '@app/domain';
import { WebSocketAdapter, enablePrefilter } from '@app/infra'; import { WebSocketAdapter, databaseChecks } from '@app/infra';
import { ImmichLogger } from '@app/infra/logger'; import { ImmichLogger } from '@app/infra/logger';
import { NestFactory } from '@nestjs/core'; import { NestFactory } from '@nestjs/core';
import { NestExpressApplication } from '@nestjs/platform-express'; import { NestExpressApplication } from '@nestjs/platform-express';
@ -31,7 +31,7 @@ export async function bootstrap() {
app.useStaticAssets('www'); app.useStaticAssets('www');
app.use(app.get(AppService).ssr(excludePaths)); app.use(app.get(AppService).ssr(excludePaths));
await enablePrefilter(); await databaseChecks();
const server = await app.listen(port); const server = await app.listen(port);
server.requestTimeout = 30 * 60 * 1000; server.requestTimeout = 30 * 60 * 1000;

View file

@ -1,4 +1,4 @@
import { DataSource } from 'typeorm'; import { DataSource, QueryRunner } from 'typeorm';
import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions'; import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions';
const url = process.env.DB_URL; const url = process.env.DB_URL;
@ -26,9 +26,50 @@ export const databaseConfig: PostgresConnectionOptions = {
// this export is used by TypeORM commands in package.json#scripts // this export is used by TypeORM commands in package.json#scripts
export const dataSource = new DataSource(databaseConfig); export const dataSource = new DataSource(databaseConfig);
export async function enablePrefilter() { export async function databaseChecks() {
if (!dataSource.isInitialized) { if (!dataSource.isInitialized) {
await dataSource.initialize(); await dataSource.initialize();
} }
await dataSource.query(`SET vectors.enable_prefilter = on`);
await assertVectors(dataSource);
await enablePrefilter(dataSource);
await dataSource.runMigrations();
}
export async function enablePrefilter(runner: DataSource | QueryRunner) {
await runner.query(`SET vectors.enable_prefilter = on`);
}
export async function getExtensionVersion(extName: string, runner: DataSource | QueryRunner): Promise<string | null> {
const res = await runner.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extName]);
return res[0]?.['extversion'] ?? null;
}
export async function getPostgresVersion(runner: DataSource | QueryRunner): Promise<string> {
const res = await runner.query(`SHOW server_version`);
return res[0]['server_version'].split('.')[0];
}
export async function assertVectors(runner: DataSource | QueryRunner) {
const postgresVersion = await getPostgresVersion(runner);
const expected = ['0.1.1', '0.1.11'];
const image = `tensorchord/pgvecto-rs:pg${postgresVersion}-v${expected[expected.length - 1]}`;
await runner.query('CREATE EXTENSION IF NOT EXISTS vectors').catch((err) => {
console.error(
'Failed to create pgvecto.rs extension. ' +
`If you have not updated your Postgres instance to an image that supports pgvecto.rs (such as '${image}'), please do so. ` +
'See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0',
);
throw err;
});
const version = await getExtensionVersion('vectors', runner);
if (version != null && !expected.includes(version)) {
throw new Error(
`The pgvecto.rs extension version is ${version} instead of the expected version ${
expected[expected.length - 1]
}.` + `If you're using the 'latest' tag, please switch to '${image}'.`,
);
}
} }

View file

@ -1,10 +1,13 @@
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
import { MigrationInterface, QueryRunner } from 'typeorm'; import { MigrationInterface, QueryRunner } from 'typeorm';
import { assertVectors } from '../database.config';
export class UsePgVectors1700713871511 implements MigrationInterface { export class UsePgVectors1700713871511 implements MigrationInterface {
name = 'UsePgVectors1700713871511'; name = 'UsePgVectors1700713871511';
public async up(queryRunner: QueryRunner): Promise<void> { public async up(queryRunner: QueryRunner): Promise<void> {
await assertVectors(queryRunner);
const faceDimQuery = await queryRunner.query(` const faceDimQuery = await queryRunner.query(`
SELECT CARDINALITY(embedding::real[]) as dimsize SELECT CARDINALITY(embedding::real[]) as dimsize
FROM asset_faces FROM asset_faces
@ -15,8 +18,6 @@ export class UsePgVectors1700713871511 implements MigrationInterface {
const clipModelName: string = clipModelNameQuery?.[0]?.['value'] ?? 'ViT-B-32__openai'; const clipModelName: string = clipModelNameQuery?.[0]?.['value'] ?? 'ViT-B-32__openai';
const clipDimSize = getCLIPModelInfo(clipModelName.replace(/"/g, '')).dimSize; const clipDimSize = getCLIPModelInfo(clipModelName.replace(/"/g, '')).dimSize;
await queryRunner.query('CREATE EXTENSION IF NOT EXISTS vectors');
await queryRunner.query(` await queryRunner.query(`
ALTER TABLE asset_faces ALTER TABLE asset_faces
ALTER COLUMN embedding SET NOT NULL, ALTER COLUMN embedding SET NOT NULL,

View file

@ -1,5 +1,5 @@
import { envName, serverVersion } from '@app/domain'; import { envName, serverVersion } from '@app/domain';
import { WebSocketAdapter, enablePrefilter } from '@app/infra'; import { WebSocketAdapter, databaseChecks } from '@app/infra';
import { ImmichLogger } from '@app/infra/logger'; import { ImmichLogger } from '@app/infra/logger';
import { NestFactory } from '@nestjs/core'; import { NestFactory } from '@nestjs/core';
import { MicroservicesModule } from './microservices.module'; import { MicroservicesModule } from './microservices.module';
@ -12,7 +12,7 @@ export async function bootstrap() {
app.useLogger(app.get(ImmichLogger)); app.useLogger(app.get(ImmichLogger));
app.useWebSocketAdapter(new WebSocketAdapter(app)); app.useWebSocketAdapter(new WebSocketAdapter(app));
await enablePrefilter(); await databaseChecks();
await app.listen(port); await app.listen(port);

View file

@ -1,6 +1,6 @@
import { AssetCreate, IJobRepository, JobItem, JobItemHandler, LibraryResponseDto, QueueName } from '@app/domain'; import { AssetCreate, IJobRepository, JobItem, JobItemHandler, LibraryResponseDto, QueueName } from '@app/domain';
import { AppModule } from '@app/immich'; import { AppModule } from '@app/immich';
import { dataSource } from '@app/infra'; import { dataSource, databaseChecks } from '@app/infra';
import { AssetEntity, AssetType, LibraryType } from '@app/infra/entities'; import { AssetEntity, AssetType, LibraryType } from '@app/infra/entities';
import { INestApplication } from '@nestjs/common'; import { INestApplication } from '@nestjs/common';
import { Test } from '@nestjs/testing'; import { Test } from '@nestjs/testing';
@ -22,11 +22,7 @@ export interface ResetOptions {
} }
export const db = { export const db = {
reset: async (options?: ResetOptions) => { reset: async (options?: ResetOptions) => {
if (!dataSource.isInitialized) { await databaseChecks();
await dataSource.initialize();
}
await dataSource.query(`SET vectors.enable_prefilter = on`);
await dataSource.transaction(async (em) => { await dataSource.transaction(async (em) => {
const entities = options?.entities || []; const entities = options?.entities || [];
const tableNames = const tableNames =