diff --git a/server/src/immich/main.ts b/server/src/immich/main.ts index 84bc5bd1c2..2b2158ce62 100644 --- a/server/src/immich/main.ts +++ b/server/src/immich/main.ts @@ -1,5 +1,5 @@ import { envName, isDev, serverVersion } from '@app/domain'; -import { WebSocketAdapter, enablePrefilter } from '@app/infra'; +import { WebSocketAdapter, databaseChecks } from '@app/infra'; import { ImmichLogger } from '@app/infra/logger'; import { NestFactory } from '@nestjs/core'; import { NestExpressApplication } from '@nestjs/platform-express'; @@ -31,7 +31,7 @@ export async function bootstrap() { app.useStaticAssets('www'); app.use(app.get(AppService).ssr(excludePaths)); - await enablePrefilter(); + await databaseChecks(); const server = await app.listen(port); server.requestTimeout = 30 * 60 * 1000; diff --git a/server/src/infra/database.config.ts b/server/src/infra/database.config.ts index 191c428337..eb41b17bb3 100644 --- a/server/src/infra/database.config.ts +++ b/server/src/infra/database.config.ts @@ -1,4 +1,4 @@ -import { DataSource } from 'typeorm'; +import { DataSource, QueryRunner } from 'typeorm'; import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions'; const url = process.env.DB_URL; @@ -26,9 +26,50 @@ export const databaseConfig: PostgresConnectionOptions = { // this export is used by TypeORM commands in package.json#scripts export const dataSource = new DataSource(databaseConfig); -export async function enablePrefilter() { +export async function databaseChecks() { if (!dataSource.isInitialized) { await dataSource.initialize(); } - await dataSource.query(`SET vectors.enable_prefilter = on`); + + await assertVectors(dataSource); + await enablePrefilter(dataSource); + await dataSource.runMigrations(); +} + +export async function enablePrefilter(runner: DataSource | QueryRunner) { + await runner.query(`SET vectors.enable_prefilter = on`); +} + +export async function getExtensionVersion(extName: string, runner: DataSource | QueryRunner): Promise { + const res = await runner.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extName]); + return res[0]?.['extversion'] ?? null; +} + +export async function getPostgresVersion(runner: DataSource | QueryRunner): Promise { + const res = await runner.query(`SHOW server_version`); + return res[0]['server_version'].split('.')[0]; +} + +export async function assertVectors(runner: DataSource | QueryRunner) { + const postgresVersion = await getPostgresVersion(runner); + const expected = ['0.1.1', '0.1.11']; + const image = `tensorchord/pgvecto-rs:pg${postgresVersion}-v${expected[expected.length - 1]}`; + + await runner.query('CREATE EXTENSION IF NOT EXISTS vectors').catch((err) => { + console.error( + 'Failed to create pgvecto.rs extension. ' + + `If you have not updated your Postgres instance to an image that supports pgvecto.rs (such as '${image}'), please do so. ` + + 'See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0', + ); + throw err; + }); + + const version = await getExtensionVersion('vectors', runner); + if (version != null && !expected.includes(version)) { + throw new Error( + `The pgvecto.rs extension version is ${version} instead of the expected version ${ + expected[expected.length - 1] + }.` + `If you're using the 'latest' tag, please switch to '${image}'.`, + ); + } } diff --git a/server/src/infra/migrations/1700713871511-UsePgVectors.ts b/server/src/infra/migrations/1700713871511-UsePgVectors.ts index 96882a0cf4..9b13f83643 100644 --- a/server/src/infra/migrations/1700713871511-UsePgVectors.ts +++ b/server/src/infra/migrations/1700713871511-UsePgVectors.ts @@ -1,10 +1,13 @@ import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { MigrationInterface, QueryRunner } from 'typeorm'; +import { assertVectors } from '../database.config'; export class UsePgVectors1700713871511 implements MigrationInterface { name = 'UsePgVectors1700713871511'; public async up(queryRunner: QueryRunner): Promise { + await assertVectors(queryRunner); + const faceDimQuery = await queryRunner.query(` SELECT CARDINALITY(embedding::real[]) as dimsize FROM asset_faces @@ -15,8 +18,6 @@ export class UsePgVectors1700713871511 implements MigrationInterface { const clipModelName: string = clipModelNameQuery?.[0]?.['value'] ?? 'ViT-B-32__openai'; const clipDimSize = getCLIPModelInfo(clipModelName.replace(/"/g, '')).dimSize; - await queryRunner.query('CREATE EXTENSION IF NOT EXISTS vectors'); - await queryRunner.query(` ALTER TABLE asset_faces ALTER COLUMN embedding SET NOT NULL, diff --git a/server/src/microservices/main.ts b/server/src/microservices/main.ts index c7e0662800..c50fa94252 100644 --- a/server/src/microservices/main.ts +++ b/server/src/microservices/main.ts @@ -1,5 +1,5 @@ import { envName, serverVersion } from '@app/domain'; -import { WebSocketAdapter, enablePrefilter } from '@app/infra'; +import { WebSocketAdapter, databaseChecks } from '@app/infra'; import { ImmichLogger } from '@app/infra/logger'; import { NestFactory } from '@nestjs/core'; import { MicroservicesModule } from './microservices.module'; @@ -12,7 +12,7 @@ export async function bootstrap() { app.useLogger(app.get(ImmichLogger)); app.useWebSocketAdapter(new WebSocketAdapter(app)); - await enablePrefilter(); + await databaseChecks(); await app.listen(port); diff --git a/server/test/test-utils.ts b/server/test/test-utils.ts index 5e99b535ab..c870123b09 100644 --- a/server/test/test-utils.ts +++ b/server/test/test-utils.ts @@ -1,6 +1,6 @@ import { AssetCreate, IJobRepository, JobItem, JobItemHandler, LibraryResponseDto, QueueName } from '@app/domain'; import { AppModule } from '@app/immich'; -import { dataSource } from '@app/infra'; +import { dataSource, databaseChecks } from '@app/infra'; import { AssetEntity, AssetType, LibraryType } from '@app/infra/entities'; import { INestApplication } from '@nestjs/common'; import { Test } from '@nestjs/testing'; @@ -22,11 +22,7 @@ export interface ResetOptions { } export const db = { reset: async (options?: ResetOptions) => { - if (!dataSource.isInitialized) { - await dataSource.initialize(); - } - - await dataSource.query(`SET vectors.enable_prefilter = on`); + await databaseChecks(); await dataSource.transaction(async (em) => { const entities = options?.entities || []; const tableNames =