diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 58761eea05..d9e7df74ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -275,7 +275,7 @@ jobs: runs-on: ubuntu-latest services: postgres: - image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee + image: tensorchord/pgvecto-rs:pg14-v0.2.0 env: POSTGRES_PASSWORD: postgres POSTGRES_USER: postgres diff --git a/cli/test/e2e/setup.ts b/cli/test/e2e/setup.ts index 52b2ae082c..fb1d939eba 100644 --- a/cli/test/e2e/setup.ts +++ b/cli/test/e2e/setup.ts @@ -25,7 +25,7 @@ export default async () => { if (process.env.DB_HOSTNAME === undefined) { // DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container. - const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11') + const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0') .withExposedPorts(5432) .withDatabase('immich') .withUsername('postgres') diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index 5290e990e2..e380be4965 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -103,7 +103,7 @@ services: database: container_name: immich_postgres - image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee + image: tensorchord/pgvecto-rs:pg14-v0.2.0 env_file: - .env environment: diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 04215b757b..857aa31540 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -61,7 +61,7 @@ services: database: container_name: immich_postgres - image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee + image: tensorchord/pgvecto-rs:pg14-v0.2.0 env_file: - .env environment: @@ -70,7 +70,8 @@ services: POSTGRES_DB: ${DB_DATABASE_NAME} volumes: - ${UPLOAD_LOCATION}/postgres:/var/lib/postgresql/data - restart: always + ports: + - 5432:5432 volumes: model-cache: diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index a6e6aa26ff..a9d8e2b38c 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -65,7 +65,7 @@ services: database: container_name: immich_postgres - image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee + image: tensorchord/pgvecto-rs:pg14-v0.2.0 env_file: - .env environment: diff --git a/server/e2e/api/setup.ts b/server/e2e/api/setup.ts index 7223a1c028..8d44b07cbb 100644 --- a/server/e2e/api/setup.ts +++ b/server/e2e/api/setup.ts @@ -1,7 +1,7 @@ import { PostgreSqlContainer } from '@testcontainers/postgresql'; export default async () => { - const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11') + const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0') .withDatabase('immich') .withUsername('postgres') .withPassword('postgres') diff --git a/server/e2e/docker-compose.server-e2e.yml b/server/e2e/docker-compose.server-e2e.yml index c9d656cedb..708ae2ca34 100644 --- a/server/e2e/docker-compose.server-e2e.yml +++ b/server/e2e/docker-compose.server-e2e.yml @@ -21,7 +21,7 @@ services: - database database: - image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee + image: tensorchord/pgvecto-rs:pg14-v0.2.0 command: -c fsync=off -c shared_preload_libraries=vectors.so environment: POSTGRES_PASSWORD: postgres diff --git a/server/e2e/jobs/setup.ts b/server/e2e/jobs/setup.ts index 601d99cc28..d1f566d372 100644 --- a/server/e2e/jobs/setup.ts +++ b/server/e2e/jobs/setup.ts @@ -25,7 +25,7 @@ export default async () => { if (process.env.DB_HOSTNAME === undefined) { // DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container. - const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11') + const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0') .withExposedPorts(5432) .withDatabase('immich') .withUsername('postgres') diff --git a/server/src/domain/database/database.service.spec.ts b/server/src/domain/database/database.service.spec.ts index 7608f61111..703805b065 100644 --- a/server/src/domain/database/database.service.spec.ts +++ b/server/src/domain/database/database.service.spec.ts @@ -1,41 +1,65 @@ -import { DatabaseExtension, DatabaseService, IDatabaseRepository, Version } from '@app/domain'; +import { + DatabaseExtension, + DatabaseService, + IDatabaseRepository, + VectorIndex, + Version, + VersionType, +} from '@app/domain'; import { ImmichLogger } from '@app/infra/logger'; import { newDatabaseRepositoryMock } from '@test'; describe(DatabaseService.name, () => { let sut: DatabaseService; let databaseMock: jest.Mocked; - let fatalLog: jest.SpyInstance; beforeEach(async () => { databaseMock = newDatabaseRepositoryMock(); - fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal'); sut = new DatabaseService(databaseMock); - - sut.minVectorsVersion = new Version(0, 1, 1); - sut.maxVectorsVersion = new Version(0, 1, 11); - }); - - afterEach(() => { - fatalLog.mockRestore(); }); it('should work', () => { expect(sut).toBeDefined(); }); - describe('init', () => { - it('should resolve successfully if minimum supported PostgreSQL and vectors version are installed', async () => { + describe.each([ + [{ vectorExt: DatabaseExtension.VECTORS, extName: 'pgvecto.rs', minVersion: new Version(0, 1, 1) }], + [{ vectorExt: DatabaseExtension.VECTOR, extName: 'pgvector', minVersion: new Version(0, 5, 0) }], + ] as const)('init', ({ vectorExt, extName, minVersion }) => { + let fatalLog: jest.SpyInstance; + let errorLog: jest.SpyInstance; + let warnLog: jest.SpyInstance; + + beforeEach(async () => { + fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal'); + errorLog = jest.spyOn(ImmichLogger.prototype, 'error'); + warnLog = jest.spyOn(ImmichLogger.prototype, 'warn'); + databaseMock.getPreferredVectorExtension.mockReturnValue(vectorExt); + databaseMock.getExtensionVersion.mockResolvedValue(minVersion); + + sut = new DatabaseService(databaseMock); + + sut.minVectorVersion = minVersion; + sut.minVectorsVersion = minVersion; + sut.vectorVersionPin = VersionType.MINOR; + sut.vectorsVersionPin = VersionType.MINOR; + }); + + afterEach(() => { + fatalLog.mockRestore(); + warnLog.mockRestore(); + }); + + it(`should resolve successfully if minimum supported PostgreSQL and ${extName} version are installed`, async () => { databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(14, 0, 0)); - databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1)); await expect(sut.init()).resolves.toBeUndefined(); - expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(2); - expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); + expect(databaseMock.getPostgresVersion).toHaveBeenCalled(); + expect(databaseMock.createExtension).toHaveBeenCalledWith(vectorExt); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.getExtensionVersion).toHaveBeenCalled(); expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); expect(fatalLog).not.toHaveBeenCalled(); }); @@ -43,112 +67,162 @@ describe(DatabaseService.name, () => { it('should throw an error if PostgreSQL version is below minimum supported version', async () => { databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(13, 0, 0)); - await expect(sut.init()).rejects.toThrow(/PostgreSQL version is 13/s); + await expect(sut.init()).rejects.toThrow('PostgreSQL version is 13'); expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(1); }); - it('should resolve successfully if minimum supported vectors version is installed', async () => { - databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1)); - + it(`should resolve successfully if minimum supported ${extName} version is installed`, async () => { await expect(sut.init()).resolves.toBeUndefined(); - expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); + expect(databaseMock.createExtension).toHaveBeenCalledWith(vectorExt); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); expect(fatalLog).not.toHaveBeenCalled(); }); - it('should resolve successfully if maximum supported vectors version is installed', async () => { - databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 11)); + it(`should throw an error if ${extName} version is not installed even after createVectorExtension`, async () => { + databaseMock.getExtensionVersion.mockResolvedValue(null); - await expect(sut.init()).resolves.toBeUndefined(); + await expect(sut.init()).rejects.toThrow(`Unexpected: ${extName} extension is not installed.`); - expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); - expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); - expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); - expect(fatalLog).not.toHaveBeenCalled(); - }); - - it('should throw an error if vectors version is not installed even after createVectors', async () => { - databaseMock.getExtensionVersion.mockResolvedValueOnce(null); - - await expect(sut.init()).rejects.toThrow('Unexpected: The pgvecto.rs extension is not installed.'); - - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.runMigrations).not.toHaveBeenCalled(); }); - it('should throw an error if vectors version is below minimum supported version', async () => { - databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 1)); - - await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0.1.11')/s); - - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); - expect(databaseMock.runMigrations).not.toHaveBeenCalled(); - }); - - it('should throw an error if vectors version is above maximum supported version', async () => { - databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 12)); - - await expect(sut.init()).rejects.toThrow( - /('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s, + it(`should throw an error if ${extName} version is below minimum supported version`, async () => { + databaseMock.getExtensionVersion.mockResolvedValue( + new Version(minVersion.major, minVersion.minor - 1, minVersion.patch), ); - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + await expect(sut.init()).rejects.toThrow(extName); + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); }); - it('should throw an error if vectors version is a nightly', async () => { - databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 0)); + it.each([ + { type: VersionType.EQUAL, max: 'no', actual: 'patch' }, + { type: VersionType.PATCH, max: 'patch', actual: 'minor' }, + { type: VersionType.MINOR, max: 'minor', actual: 'major' }, + ] as const)( + `should throw an error if $max upgrade from min version is allowed and ${extName} version is $actual`, + async ({ type, actual }) => { + const version = new Version(minVersion.major, minVersion.minor, minVersion.patch); + version[actual] = minVersion[actual] + 1; + databaseMock.getExtensionVersion.mockResolvedValue(version); + if (vectorExt === DatabaseExtension.VECTOR) { + sut.minVectorVersion = minVersion; + sut.vectorVersionPin = type; + } else { + sut.minVectorsVersion = minVersion; + sut.vectorsVersionPin = type; + } - await expect(sut.init()).rejects.toThrow( - /(nightly).*('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s, - ); + await expect(sut.init()).rejects.toThrow(extName); + + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); + }, + ); + + it(`should throw an error if ${extName} version is a nightly`, async () => { + databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 0)); + + await expect(sut.init()).rejects.toThrow(extName); - expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.runMigrations).not.toHaveBeenCalled(); }); - it('should throw error if vectors extension could not be created', async () => { - databaseMock.createExtension.mockRejectedValueOnce(new Error('Failed to create extension')); + it(`should throw error if ${extName} extension could not be created`, async () => { + databaseMock.createExtension.mockRejectedValue(new Error('Failed to create extension')); await expect(sut.init()).rejects.toThrow('Failed to create extension'); expect(fatalLog).toHaveBeenCalledTimes(1); - expect(fatalLog.mock.calls[0][0]).toMatch(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11').*(v1\.91\.0)/s); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.runMigrations).not.toHaveBeenCalled(); }); - it.each([{ major: 14 }, { major: 15 }, { major: 16 }])( - `should suggest image with postgres $major if database is $major`, - async ({ major }) => { - databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1)); - databaseMock.getPostgresVersion.mockResolvedValue(new Version(major, 0, 0)); + it(`should update ${extName} if a newer version is available`, async () => { + const version = new Version(minVersion.major, minVersion.minor + 1, minVersion.patch); + databaseMock.getAvailableExtensionVersion.mockResolvedValue(version); - await expect(sut.init()).rejects.toThrow(new RegExp(`tensorchord\/pgvecto-rs:pg${major}-v0\\.1\\.11`, 's')); + await expect(sut.init()).resolves.toBeUndefined(); + + expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version); + expect(databaseMock.updateVectorExtension).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + }); + + it(`should not update ${extName} if a newer version is higher than the maximum`, async () => { + const version = new Version(minVersion.major + 1, minVersion.minor, minVersion.patch); + databaseMock.getAvailableExtensionVersion.mockResolvedValue(version); + + await expect(sut.init()).resolves.toBeUndefined(); + + expect(databaseMock.updateVectorExtension).not.toHaveBeenCalled(); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + }); + + it(`should warn if attempted to update ${extName} and failed`, async () => { + const version = new Version(minVersion.major, minVersion.minor, minVersion.patch + 1); + databaseMock.getAvailableExtensionVersion.mockResolvedValue(version); + databaseMock.updateVectorExtension.mockRejectedValue(new Error('Failed to update extension')); + + await expect(sut.init()).resolves.toBeUndefined(); + + expect(warnLog).toHaveBeenCalledTimes(1); + expect(warnLog.mock.calls[0][0]).toContain(extName); + expect(errorLog).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + }); + + it(`should warn if ${extName} update requires restart`, async () => { + const version = new Version(minVersion.major, minVersion.minor, minVersion.patch + 1); + databaseMock.getAvailableExtensionVersion.mockResolvedValue(version); + databaseMock.updateVectorExtension.mockResolvedValue({ restartRequired: true }); + + await expect(sut.init()).resolves.toBeUndefined(); + + expect(warnLog).toHaveBeenCalledTimes(1); + expect(warnLog.mock.calls[0][0]).toContain(extName); + expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + }); + + it.each([{ index: VectorIndex.CLIP }, { index: VectorIndex.FACE }])( + `should reindex $index if necessary`, + async ({ index }) => { + databaseMock.shouldReindex.mockImplementation((indexArg) => Promise.resolve(indexArg === index)); + + await expect(sut.init()).resolves.toBeUndefined(); + + expect(databaseMock.shouldReindex).toHaveBeenCalledWith(index); + expect(databaseMock.shouldReindex).toHaveBeenCalledTimes(2); + expect(databaseMock.reindex).toHaveBeenCalledWith(index); + expect(databaseMock.reindex).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); }, ); - it('should not suggest image if postgres version is not in 14, 15 or 16', async () => { - databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(17, 0, 0)); - databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(17, 0, 0)); + it.each([{ index: VectorIndex.CLIP }, { index: VectorIndex.FACE }])( + `should not reindex $index if not necessary`, + async () => { + databaseMock.shouldReindex.mockResolvedValue(false); - await expect(sut.init()).rejects.toThrow(/^(?:(?!tensorchord\/pgvecto-rs).)*$/s); - }); + await expect(sut.init()).resolves.toBeUndefined(); - it('should reject and suggest the maximum supported version when unsupported pgvecto.rs version is in use', async () => { - databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1)); - - await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s); - - sut.maxVectorsVersion = new Version(0, 1, 12); - await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.12')/s); - }); + expect(databaseMock.shouldReindex).toHaveBeenCalledTimes(2); + expect(databaseMock.reindex).not.toHaveBeenCalled(); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + }, + ); }); }); diff --git a/server/src/domain/database/database.service.ts b/server/src/domain/database/database.service.ts index 5af576a73b..5ea9e1a474 100644 --- a/server/src/domain/database/database.service.ts +++ b/server/src/domain/database/database.service.ts @@ -1,74 +1,56 @@ import { ImmichLogger } from '@app/infra/logger'; import { Inject, Injectable } from '@nestjs/common'; import { QueryFailedError } from 'typeorm'; -import { Version } from '../domain.constant'; -import { DatabaseExtension, IDatabaseRepository } from '../repositories'; +import { Version, VersionType } from '../domain.constant'; +import { + DatabaseExtension, + DatabaseLock, + IDatabaseRepository, + VectorExtension, + VectorIndex, + extName, +} from '../repositories'; @Injectable() export class DatabaseService { private logger = new ImmichLogger(DatabaseService.name); + private vectorExt: VectorExtension; minPostgresVersion = 14; - minVectorsVersion = new Version(0, 1, 1); - maxVectorsVersion = new Version(0, 1, 11); + minVectorsVersion = new Version(0, 2, 0); + vectorsVersionPin = VersionType.MINOR; + minVectorVersion = new Version(0, 5, 0); + vectorVersionPin = VersionType.MAJOR; - constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) {} + constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) { + this.vectorExt = this.databaseRepository.getPreferredVectorExtension(); + } async init() { await this.assertPostgresql(); - await this.createVectors(); - await this.assertVectors(); - await this.databaseRepository.runMigrations(); - } + await this.databaseRepository.withLock(DatabaseLock.Migrations, async () => { + await this.createVectorExtension(); + await this.updateVectorExtension(); + await this.assertVectorExtension(); - private async assertVectors() { - const version = await this.databaseRepository.getExtensionVersion(DatabaseExtension.VECTORS); - if (version == null) { - throw new Error('Unexpected: The pgvecto.rs extension is not installed.'); - } + try { + if (await this.databaseRepository.shouldReindex(VectorIndex.CLIP)) { + await this.databaseRepository.reindex(VectorIndex.CLIP); + } - const image = await this.getVectorsImage(); - const suggestion = image ? `, such as with the docker image '${image}'` : ''; + if (await this.databaseRepository.shouldReindex(VectorIndex.FACE)) { + await this.databaseRepository.reindex(VectorIndex.FACE); + } + } catch (error) { + this.logger.warn( + 'Could not run vector reindexing checks. If the extension was updated, please restart the Postgres instance.', + ); + throw error; + } - if (version.isEqual(new Version(0, 0, 0))) { - throw new Error( - `The pgvecto.rs extension version is ${version}, which means it is a nightly release.` + - `Please run 'DROP EXTENSION IF EXISTS vectors' and switch to a release version${suggestion}.`, - ); - } - - if (version.isNewerThan(this.maxVectorsVersion)) { - throw new Error(` - The pgvecto.rs extension version is ${version} instead of ${this.maxVectorsVersion}. - Please run 'DROP EXTENSION IF EXISTS vectors' and switch to ${this.maxVectorsVersion}${suggestion}.`); - } - - if (version.isOlderThan(this.minVectorsVersion)) { - throw new Error(` - The pgvecto.rs extension version is ${version}, which is older than the minimum supported version ${this.minVectorsVersion}. - Please upgrade to this version or later${suggestion}.`); - } - } - - private async createVectors() { - await this.databaseRepository.createExtension(DatabaseExtension.VECTORS).catch(async (error: QueryFailedError) => { - const image = await this.getVectorsImage(); - this.logger.fatal(` - Failed to create pgvecto.rs extension. - If you have not updated your Postgres instance to a docker image that supports pgvecto.rs (such as '${image}'), please do so. - See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0' - `); - throw error; + await this.databaseRepository.runMigrations(); }); } - private async getVectorsImage() { - const { major } = await this.databaseRepository.getPostgresVersion(); - if (![14, 15, 16].includes(major)) { - return null; - } - return `tensorchord/pgvecto-rs:pg${major}-v${this.maxVectorsVersion}`; - } - private async assertPostgresql() { const { major } = await this.databaseRepository.getPostgresVersion(); if (major < this.minPostgresVersion) { @@ -77,4 +59,99 @@ export class DatabaseService { Please upgrade to this version or later.`); } } + + private async createVectorExtension() { + await this.databaseRepository.createExtension(this.vectorExt).catch(async (error: QueryFailedError) => { + const otherExt = + this.vectorExt === DatabaseExtension.VECTORS ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS; + this.logger.fatal(` + Failed to activate ${extName[this.vectorExt]} extension. + Please ensure the Postgres instance has ${extName[this.vectorExt]} installed. + + If the Postgres instance already has ${extName[this.vectorExt]} installed, Immich may not have the necessary permissions to activate it. + In this case, please run 'CREATE EXTENSION IF NOT EXISTS ${this.vectorExt}' manually as a superuser. + See https://immich.app/docs/guides/database-queries for how to query the database. + + Alternatively, if your Postgres instance has ${extName[otherExt]}, you may use this instead by setting the environment variable 'VECTOR_EXTENSION=${otherExt}'. + Note that switching between the two extensions after a successful startup is not supported. + The exception is if your version of Immich prior to upgrading was 1.90.2 or earlier. + In this case, you may set either extension now, but you will not be able to switch to the other extension following a successful startup. + `); + throw error; + }); + } + + private async updateVectorExtension() { + const [version, availableVersion] = await Promise.all([ + this.databaseRepository.getExtensionVersion(this.vectorExt), + this.databaseRepository.getAvailableExtensionVersion(this.vectorExt), + ]); + if (version == null) { + throw new Error(`Unexpected: ${extName[this.vectorExt]} extension is not installed.`); + } + + if (availableVersion == null) { + return; + } + + const maxVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.vectorVersionPin : this.vectorsVersionPin; + const isNewer = availableVersion.isNewerThan(version); + if (isNewer == null || isNewer > maxVersion) { + return; + } + + try { + this.logger.log(`Updating ${extName[this.vectorExt]} extension to ${availableVersion}`); + const { restartRequired } = await this.databaseRepository.updateVectorExtension(this.vectorExt, availableVersion); + if (restartRequired) { + this.logger.warn(` + The ${extName[this.vectorExt]} extension has been updated to ${availableVersion}. + Please restart the Postgres instance to complete the update.`); + } + } catch (error) { + this.logger.warn(` + The ${extName[this.vectorExt]} extension version is ${version}, but ${availableVersion} is available. + Immich attempted to update the extension, but failed to do so. + This may be because Immich does not have the necessary permissions to update the extension. + + Please run 'ALTER EXTENSION ${this.vectorExt} UPDATE' manually as a superuser. + See https://immich.app/docs/guides/database-queries for how to query the database.`); + this.logger.error(error); + } + } + + private async assertVectorExtension() { + const version = await this.databaseRepository.getExtensionVersion(this.vectorExt); + if (version == null) { + throw new Error(`Unexpected: The ${extName[this.vectorExt]} extension is not installed.`); + } + + if (version.isEqual(new Version(0, 0, 0))) { + throw new Error(` + The ${extName[this.vectorExt]} extension version is ${version}, which means it is a nightly release. + + Please run 'DROP EXTENSION IF EXISTS ${this.vectorExt}' and switch to a release version. + See https://immich.app/docs/guides/database-queries for how to query the database.`); + } + + const minVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.minVectorVersion : this.minVectorsVersion; + const maxVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.vectorVersionPin : this.vectorsVersionPin; + + if (version.isOlderThan(minVersion) || version.isNewerThan(minVersion) > maxVersion) { + const allowedReleaseType = maxVersion === VersionType.MAJOR ? '' : ` ${VersionType[maxVersion].toLowerCase()}`; + const releases = + maxVersion === VersionType.EQUAL + ? minVersion.toString() + : `${minVersion} and later${allowedReleaseType} releases`; + + throw new Error(` + The ${extName[this.vectorExt]} extension version is ${version}, but Immich only supports ${releases}. + + If the Postgres instance already has a compatible version installed, Immich may not have the necessary permissions to activate it. + In this case, please run 'ALTER EXTENSION UPDATE ${this.vectorExt}' manually as a superuser. + See https://immich.app/docs/guides/database-queries for how to query the database. + + Otherwise, please update the version of ${extName[this.vectorExt]} in the Postgres instance to a compatible version.`); + } + } } diff --git a/server/src/domain/domain.config.ts b/server/src/domain/domain.config.ts index 3a106bad2b..ed1283ec2f 100644 --- a/server/src/domain/domain.config.ts +++ b/server/src/domain/domain.config.ts @@ -24,5 +24,6 @@ export const immichAppConfig: ConfigModuleOptions = { MACHINE_LEARNING_PORT: Joi.number().optional(), MICROSERVICES_PORT: Joi.number().optional(), SERVER_PORT: Joi.number().optional(), + VECTOR_EXTENSION: Joi.string().optional().valid('pgvector', 'pgvecto.rs').default('pgvecto.rs'), }), }; diff --git a/server/src/domain/domain.constant.spec.ts b/server/src/domain/domain.constant.spec.ts index 4ec4b1124c..154128a1c2 100644 --- a/server/src/domain/domain.constant.spec.ts +++ b/server/src/domain/domain.constant.spec.ts @@ -1,4 +1,4 @@ -import { Version, mimeTypes } from './domain.constant'; +import { Version, VersionType, mimeTypes } from './domain.constant'; describe('mimeTypes', () => { for (const { mimetype, extension } of [ @@ -196,45 +196,37 @@ describe('mimeTypes', () => { }); }); -describe('ServerVersion', () => { +describe('Version', () => { const tests = [ - { this: new Version(0, 0, 1), other: new Version(0, 0, 0), expected: 1 }, - { this: new Version(0, 1, 0), other: new Version(0, 0, 0), expected: 1 }, - { this: new Version(1, 0, 0), other: new Version(0, 0, 0), expected: 1 }, - { this: new Version(0, 0, 0), other: new Version(0, 0, 1), expected: -1 }, - { this: new Version(0, 0, 0), other: new Version(0, 1, 0), expected: -1 }, - { this: new Version(0, 0, 0), other: new Version(1, 0, 0), expected: -1 }, - { this: new Version(0, 0, 0), other: new Version(0, 0, 0), expected: 0 }, - { this: new Version(0, 0, 1), other: new Version(0, 0, 1), expected: 0 }, - { this: new Version(0, 1, 0), other: new Version(0, 1, 0), expected: 0 }, - { this: new Version(1, 0, 0), other: new Version(1, 0, 0), expected: 0 }, - { this: new Version(1, 0), other: new Version(1, 0, 0), expected: 0 }, - { this: new Version(1, 0), other: new Version(1, 0, 1), expected: -1 }, - { this: new Version(1, 1), other: new Version(1, 0, 1), expected: 1 }, - { this: new Version(1), other: new Version(1, 0, 0), expected: 0 }, - { this: new Version(1), other: new Version(1, 0, 1), expected: -1 }, + { this: new Version(0, 0, 1), other: new Version(0, 0, 0), compare: 1, type: VersionType.PATCH }, + { this: new Version(0, 1, 0), other: new Version(0, 0, 0), compare: 1, type: VersionType.MINOR }, + { this: new Version(1, 0, 0), other: new Version(0, 0, 0), compare: 1, type: VersionType.MAJOR }, + { this: new Version(0, 0, 0), other: new Version(0, 0, 1), compare: -1, type: VersionType.PATCH }, + { this: new Version(0, 0, 0), other: new Version(0, 1, 0), compare: -1, type: VersionType.MINOR }, + { this: new Version(0, 0, 0), other: new Version(1, 0, 0), compare: -1, type: VersionType.MAJOR }, + { this: new Version(0, 0, 0), other: new Version(0, 0, 0), compare: 0, type: VersionType.EQUAL }, + { this: new Version(0, 0, 1), other: new Version(0, 0, 1), compare: 0, type: VersionType.EQUAL }, + { this: new Version(0, 1, 0), other: new Version(0, 1, 0), compare: 0, type: VersionType.EQUAL }, + { this: new Version(1, 0, 0), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL }, + { this: new Version(1, 0), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL }, + { this: new Version(1, 0), other: new Version(1, 0, 1), compare: -1, type: VersionType.PATCH }, + { this: new Version(1, 1), other: new Version(1, 0, 1), compare: 1, type: VersionType.MINOR }, + { this: new Version(1), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL }, + { this: new Version(1), other: new Version(1, 0, 1), compare: -1, type: VersionType.PATCH }, ]; - describe('compare', () => { - for (const { this: thisVersion, other: otherVersion, expected } of tests) { - it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => { - expect(thisVersion.compare(otherVersion)).toEqual(expected); - }); - } - }); - describe('isOlderThan', () => { - for (const { this: thisVersion, other: otherVersion, expected } of tests) { - const bool = expected < 0; - it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { - expect(thisVersion.isOlderThan(otherVersion)).toEqual(bool); + for (const { this: thisVersion, other: otherVersion, compare, type } of tests) { + const expected = compare < 0 ? type : VersionType.EQUAL; + it(`should return '${expected}' when comparing ${thisVersion} to ${otherVersion}`, () => { + expect(thisVersion.isOlderThan(otherVersion)).toEqual(expected); }); } }); describe('isEqual', () => { - for (const { this: thisVersion, other: otherVersion, expected } of tests) { - const bool = expected === 0; + for (const { this: thisVersion, other: otherVersion, compare } of tests) { + const bool = compare === 0; it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { expect(thisVersion.isEqual(otherVersion)).toEqual(bool); }); @@ -242,10 +234,10 @@ describe('ServerVersion', () => { }); describe('isNewerThan', () => { - for (const { this: thisVersion, other: otherVersion, expected } of tests) { - const bool = expected > 0; - it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { - expect(thisVersion.isNewerThan(otherVersion)).toEqual(bool); + for (const { this: thisVersion, other: otherVersion, compare, type } of tests) { + const expected = compare > 0 ? type : VersionType.EQUAL; + it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => { + expect(thisVersion.isNewerThan(otherVersion)).toEqual(expected); }); } }); diff --git a/server/src/domain/domain.constant.ts b/server/src/domain/domain.constant.ts index 227595e04f..4e7c4d5524 100644 --- a/server/src/domain/domain.constant.ts +++ b/server/src/domain/domain.constant.ts @@ -12,11 +12,20 @@ export interface IVersion { patch: number; } +export enum VersionType { + EQUAL = 0, + PATCH = 1, + MINOR = 2, + MAJOR = 3, +} + export class Version implements IVersion { + public readonly types = ['major', 'minor', 'patch'] as const; + constructor( - public readonly major: number, - public readonly minor: number = 0, - public readonly patch: number = 0, + public major: number, + public minor: number = 0, + public patch: number = 0, ) {} toString() { @@ -39,27 +48,30 @@ export class Version implements IVersion { } } - compare(version: Version): number { - for (const key of ['major', 'minor', 'patch'] as const) { + private compare(version: Version): [number, VersionType] { + for (const [i, key] of this.types.entries()) { const diff = this[key] - version[key]; if (diff !== 0) { - return diff > 0 ? 1 : -1; + return [diff > 0 ? 1 : -1, (VersionType.MAJOR - i) as VersionType]; } } - return 0; + return [0, VersionType.EQUAL]; } - isOlderThan(version: Version): boolean { - return this.compare(version) < 0; + isOlderThan(version: Version): VersionType { + const [bool, type] = this.compare(version); + return bool < 0 ? type : VersionType.EQUAL; } isEqual(version: Version): boolean { - return this.compare(version) === 0; + const [bool] = this.compare(version); + return bool === 0; } - isNewerThan(version: Version): boolean { - return this.compare(version) > 0; + isNewerThan(version: Version): VersionType { + const [bool, type] = this.compare(version); + return bool > 0 ? type : VersionType.EQUAL; } } diff --git a/server/src/domain/repositories/database.repository.ts b/server/src/domain/repositories/database.repository.ts index 07d0afca6b..d32939fe61 100644 --- a/server/src/domain/repositories/database.repository.ts +++ b/server/src/domain/repositories/database.repository.ts @@ -3,21 +3,47 @@ import { Version } from '../domain.constant'; export enum DatabaseExtension { CUBE = 'cube', EARTH_DISTANCE = 'earthdistance', + VECTOR = 'vector', VECTORS = 'vectors', } +export type VectorExtension = DatabaseExtension.VECTOR | DatabaseExtension.VECTORS; + +export enum VectorIndex { + CLIP = 'clip_index', + FACE = 'face_index', +} + export enum DatabaseLock { GeodataImport = 100, + Migrations = 200, StorageTemplateMigration = 420, CLIPDimSize = 512, } +export const extName: Record = { + cube: 'cube', + earthdistance: 'earthdistance', + vector: 'pgvector', + vectors: 'pgvecto.rs', +} as const; + +export interface VectorUpdateResult { + restartRequired: boolean; +} + export const IDatabaseRepository = 'IDatabaseRepository'; export interface IDatabaseRepository { getExtensionVersion(extensionName: string): Promise; + getAvailableExtensionVersion(extension: DatabaseExtension): Promise; + getPreferredVectorExtension(): VectorExtension; getPostgresVersion(): Promise; createExtension(extension: DatabaseExtension): Promise; + updateExtension(extension: DatabaseExtension, version?: Version): Promise; + updateVectorExtension(extension: VectorExtension, version?: Version): Promise; + reindex(index: VectorIndex): Promise; + shouldReindex(name: VectorIndex): Promise; runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise; withLock(lock: DatabaseLock, callback: () => Promise): Promise; isBusy(lock: DatabaseLock): boolean; diff --git a/server/src/domain/repositories/smart-info.repository.ts b/server/src/domain/repositories/smart-info.repository.ts index 7b82e9d744..acb907bc8f 100644 --- a/server/src/domain/repositories/smart-info.repository.ts +++ b/server/src/domain/repositories/smart-info.repository.ts @@ -7,7 +7,7 @@ export type Embedding = number[]; export interface EmbeddingSearch { userIds: string[]; embedding: Embedding; - numResults?: number; + numResults: number; withArchived?: boolean; } diff --git a/server/src/infra/database.config.ts b/server/src/infra/database.config.ts index 9e6cccd198..93926e51cf 100644 --- a/server/src/infra/database.config.ts +++ b/server/src/infra/database.config.ts @@ -1,3 +1,4 @@ +import { DatabaseExtension } from '@app/domain/repositories/database.repository'; import { DataSource } from 'typeorm'; import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions.js'; @@ -27,3 +28,6 @@ export const databaseConfig: PostgresConnectionOptions = { // this export is used by TypeORM commands in package.json#scripts export const dataSource = new DataSource(databaseConfig); + +export const vectorExt = + process.env.VECTOR_EXTENSION === 'pgvector' ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS; diff --git a/server/src/infra/logger.ts b/server/src/infra/logger.ts index 183ffb492f..8de149c409 100644 --- a/server/src/infra/logger.ts +++ b/server/src/infra/logger.ts @@ -5,7 +5,7 @@ import { LogLevel } from './entities'; const LOG_LEVELS = [LogLevel.VERBOSE, LogLevel.DEBUG, LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; export class ImmichLogger extends ConsoleLogger { - private static logLevels: LogLevel[] = [LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; + private static logLevels: LogLevel[] = [LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; constructor(context: string) { super(context); diff --git a/server/src/infra/migrations/1700713871511-UsePgVectors.ts b/server/src/infra/migrations/1700713871511-UsePgVectors.ts index a952f1646d..008d5eadc8 100644 --- a/server/src/infra/migrations/1700713871511-UsePgVectors.ts +++ b/server/src/infra/migrations/1700713871511-UsePgVectors.ts @@ -1,11 +1,13 @@ import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { MigrationInterface, QueryRunner } from 'typeorm'; +import { vectorExt } from '@app/infra/database.config'; export class UsePgVectors1700713871511 implements MigrationInterface { name = 'UsePgVectors1700713871511'; public async up(queryRunner: QueryRunner): Promise { - await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS vectors`); + await queryRunner.query(`SET search_path TO "$user", public, vectors`); + await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS ${vectorExt}`); const faceDimQuery = await queryRunner.query(` SELECT CARDINALITY(embedding::real[]) as dimsize FROM asset_faces diff --git a/server/src/infra/migrations/1700713994428-AddCLIPEmbeddingIndex.ts b/server/src/infra/migrations/1700713994428-AddCLIPEmbeddingIndex.ts index 7a1a1144d6..c3716cc191 100644 --- a/server/src/infra/migrations/1700713994428-AddCLIPEmbeddingIndex.ts +++ b/server/src/infra/migrations/1700713994428-AddCLIPEmbeddingIndex.ts @@ -1,16 +1,20 @@ import { MigrationInterface, QueryRunner } from 'typeorm'; +import { vectorExt } from '../database.config'; +import { DatabaseExtension } from '@app/domain/repositories/database.repository'; export class AddCLIPEmbeddingIndex1700713994428 implements MigrationInterface { name = 'AddCLIPEmbeddingIndex1700713994428'; public async up(queryRunner: QueryRunner): Promise { + if (vectorExt === DatabaseExtension.VECTORS) { + await queryRunner.query(`SET vectors.pgvector_compatibility=on`); + } + await queryRunner.query(`SET search_path TO "$user", public, vectors`); + await queryRunner.query(` CREATE INDEX IF NOT EXISTS clip_index ON smart_search - USING vectors (embedding cosine_ops) WITH (options = $$ - [indexing.hnsw] - m = 16 - ef_construction = 300 - $$);`); + USING hnsw (embedding vector_cosine_ops) + WITH (ef_construction = 300, m = 16)`); } public async down(queryRunner: QueryRunner): Promise { diff --git a/server/src/infra/migrations/1700714033632-AddFaceEmbeddingIndex.ts b/server/src/infra/migrations/1700714033632-AddFaceEmbeddingIndex.ts index 0ac7b0cd4c..066303530a 100644 --- a/server/src/infra/migrations/1700714033632-AddFaceEmbeddingIndex.ts +++ b/server/src/infra/migrations/1700714033632-AddFaceEmbeddingIndex.ts @@ -1,16 +1,20 @@ import { MigrationInterface, QueryRunner } from 'typeorm'; +import { vectorExt } from '../database.config'; +import { DatabaseExtension } from '@app/domain/repositories/database.repository'; export class AddFaceEmbeddingIndex1700714033632 implements MigrationInterface { name = 'AddFaceEmbeddingIndex1700714033632'; public async up(queryRunner: QueryRunner): Promise { + if (vectorExt === DatabaseExtension.VECTORS) { + await queryRunner.query(`SET vectors.pgvector_compatibility=on`); + } + await queryRunner.query(`SET search_path TO "$user", public, vectors`); + await queryRunner.query(` CREATE INDEX IF NOT EXISTS face_index ON asset_faces - USING vectors (embedding cosine_ops) WITH (options = $$ - [indexing.hnsw] - m = 16 - ef_construction = 300 - $$);`); + USING hnsw (embedding vector_cosine_ops) + WITH (ef_construction = 300, m = 16)`); } public async down(queryRunner: QueryRunner): Promise { diff --git a/server/src/infra/migrations/1707000751533-AddVectorsToSearchPath.ts b/server/src/infra/migrations/1707000751533-AddVectorsToSearchPath.ts new file mode 100644 index 0000000000..e83e4b4fb0 --- /dev/null +++ b/server/src/infra/migrations/1707000751533-AddVectorsToSearchPath.ts @@ -0,0 +1,14 @@ +import { MigrationInterface, QueryRunner } from 'typeorm'; + +export class AddVectorsToSearchPath1707000751533 implements MigrationInterface { + public async up(queryRunner: QueryRunner): Promise { + const res = await queryRunner.query(`SELECT current_database() as db`); + const databaseName = res[0]['db']; + await queryRunner.query(`ALTER DATABASE ${databaseName} SET search_path TO "$user", public, vectors`); + } + + public async down(queryRunner: QueryRunner): Promise { + const databaseName = await queryRunner.query(`SELECT current_database()`); + await queryRunner.query(`ALTER DATABASE ${databaseName} SET search_path TO "$user", public`); + } +} diff --git a/server/src/infra/repositories/database.repository.ts b/server/src/infra/repositories/database.repository.ts index af595057e2..b0e4623af5 100644 --- a/server/src/infra/repositories/database.repository.ts +++ b/server/src/infra/repositories/database.repository.ts @@ -1,21 +1,60 @@ -import { DatabaseExtension, DatabaseLock, IDatabaseRepository, Version } from '@app/domain'; +import { + DatabaseExtension, + DatabaseLock, + IDatabaseRepository, + VectorExtension, + VectorIndex, + VectorUpdateResult, + Version, + VersionType, + extName, +} from '@app/domain'; +import { vectorExt } from '@app/infra/database.config'; import { Injectable } from '@nestjs/common'; import { InjectDataSource } from '@nestjs/typeorm'; import AsyncLock from 'async-lock'; -import { DataSource, QueryRunner } from 'typeorm'; +import { DataSource, EntityManager, QueryRunner } from 'typeorm'; +import { isValidInteger } from '../infra.utils'; +import { ImmichLogger } from '../logger'; @Injectable() export class DatabaseRepository implements IDatabaseRepository { + private logger = new ImmichLogger(DatabaseRepository.name); readonly asyncLock = new AsyncLock(); constructor(@InjectDataSource() private dataSource: DataSource) {} async getExtensionVersion(extension: DatabaseExtension): Promise { const res = await this.dataSource.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extension]); - const version = res[0]?.['extversion']; + const extVersion = res[0]?.['extversion']; + if (extVersion == null) { + return null; + } + + const version = Version.fromString(extVersion); + if (version.isEqual(new Version(0, 1, 1))) { + return new Version(0, 1, 11); + } + + return version; + } + + async getAvailableExtensionVersion(extension: DatabaseExtension): Promise { + const res = await this.dataSource.query( + ` + SELECT version FROM pg_available_extension_versions + WHERE name = $1 AND installed = false + ORDER BY version DESC`, + [extension], + ); + const version = res[0]?.['version']; return version == null ? null : Version.fromString(version); } + getPreferredVectorExtension(): VectorExtension { + return vectorExt; + } + async getPostgresVersion(): Promise { const res = await this.dataSource.query(`SHOW server_version`); return Version.fromString(res[0]['server_version']); @@ -25,6 +64,129 @@ export class DatabaseRepository implements IDatabaseRepository { await this.dataSource.query(`CREATE EXTENSION IF NOT EXISTS ${extension}`); } + async updateExtension(extension: DatabaseExtension, version?: Version): Promise { + await this.dataSource.query(`ALTER EXTENSION ${extension} UPDATE${version ? ` TO '${version}'` : ''}`); + } + + async updateVectorExtension(extension: VectorExtension, version?: Version): Promise { + const curVersion = await this.getExtensionVersion(extension); + if (!curVersion) { + throw new Error(`${extName[extension]} extension is not installed`); + } + + const minorOrMajor = version && curVersion.isOlderThan(version) >= VersionType.MINOR; + const isVectors = extension === DatabaseExtension.VECTORS; + let restartRequired = false; + await this.dataSource.manager.transaction(async (manager) => { + await this.setSearchPath(manager); + if (minorOrMajor && isVectors) { + await this.updateVectorsSchema(manager, curVersion); + } + + await manager.query(`ALTER EXTENSION ${extension} UPDATE${version ? ` TO '${version}'` : ''}`); + + if (!minorOrMajor) { + return; + } + + if (isVectors) { + await manager.query('SELECT pgvectors_upgrade()'); + restartRequired = true; + } else { + await this.reindex(VectorIndex.CLIP); + await this.reindex(VectorIndex.FACE); + } + }); + + return { restartRequired }; + } + + async reindex(index: VectorIndex): Promise { + try { + await this.dataSource.query(`REINDEX INDEX ${index}`); + } catch (error) { + if (vectorExt === DatabaseExtension.VECTORS) { + this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`); + const table = index === VectorIndex.CLIP ? 'smart_search' : 'asset_faces'; + const dimSize = await this.getDimSize(table); + await this.dataSource.manager.transaction(async (manager) => { + await this.setSearchPath(manager); + await manager.query(`DROP INDEX IF EXISTS ${index}`); + await manager.query(`ALTER TABLE ${table} ALTER COLUMN embedding SET DATA TYPE real[]`); + await manager.query(`ALTER TABLE ${table} ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`); + await manager.query(`SET vectors.pgvector_compatibility=on`); + await manager.query(` + CREATE INDEX IF NOT EXISTS ${index} ON ${table} + USING hnsw (embedding vector_cosine_ops) + WITH (ef_construction = 300, m = 16)`); + }); + } else { + throw error; + } + } + } + + async shouldReindex(name: VectorIndex): Promise { + if (vectorExt !== DatabaseExtension.VECTORS) { + return false; + } + + try { + const res = await this.dataSource.query( + ` + SELECT idx_status + FROM pg_vector_index_stat + WHERE indexname = $1`, + [name], + ); + return res[0]?.['idx_status'] === 'UPGRADE'; + } catch (error) { + const message: string = (error as any).message; + if (message.includes('index is not existing')) { + return true; + } else if (message.includes('relation "pg_vector_index_stat" does not exist')) { + return false; + } + throw error; + } + } + + private async setSearchPath(manager: EntityManager): Promise { + await manager.query(`SET search_path TO "$user", public, vectors`); + } + + private async updateVectorsSchema(manager: EntityManager, curVersion: Version): Promise { + await manager.query('CREATE SCHEMA IF NOT EXISTS vectors'); + await manager.query(`UPDATE pg_catalog.pg_extension SET extversion = $1 WHERE extname = $2`, [ + curVersion.toString(), + DatabaseExtension.VECTORS, + ]); + await manager.query('UPDATE pg_catalog.pg_extension SET extrelocatable = true WHERE extname = $1', [ + DatabaseExtension.VECTORS, + ]); + await manager.query('ALTER EXTENSION vectors SET SCHEMA vectors'); + await manager.query('UPDATE pg_catalog.pg_extension SET extrelocatable = false WHERE extname = $1', [ + DatabaseExtension.VECTORS, + ]); + } + + private async getDimSize(table: string, column = 'embedding'): Promise { + const res = await this.dataSource.query(` + SELECT atttypmod as dimsize + FROM pg_attribute f + JOIN pg_class c ON c.oid = f.attrelid + WHERE c.relkind = 'r'::char + AND f.attnum > 0 + AND c.relname = '${table}' + AND f.attname = '${column}'`); + + const dimSize = res[0]['dimsize']; + if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) { + throw new Error(`Could not retrieve dimension size`); + } + return dimSize; + } + async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise { await this.dataSource.runMigrations(options); } diff --git a/server/src/infra/repositories/smart-info.repository.ts b/server/src/infra/repositories/smart-info.repository.ts index ab43ff6f91..f74fd4232d 100644 --- a/server/src/infra/repositories/smart-info.repository.ts +++ b/server/src/infra/repositories/smart-info.repository.ts @@ -1,10 +1,18 @@ -import { Embedding, EmbeddingSearch, FaceEmbeddingSearch, FaceSearchResult, ISmartInfoRepository } from '@app/domain'; +import { + DatabaseExtension, + Embedding, + EmbeddingSearch, + FaceEmbeddingSearch, + FaceSearchResult, + ISmartInfoRepository, +} from '@app/domain'; import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities'; import { ImmichLogger } from '@app/infra/logger'; import { Injectable } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; import { Repository } from 'typeorm'; +import { vectorExt } from '../database.config'; import { DummyValue, GenerateSql } from '../infra.util'; import { asVector, isValidInteger } from '../infra.utils'; @@ -44,16 +52,20 @@ export class SmartInfoRepository implements ISmartInfoRepository { params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }], }) async searchCLIP({ userIds, embedding, numResults, withArchived }: EmbeddingSearch): Promise { + if (!isValidInteger(numResults, { min: 1 })) { + throw new Error(`Invalid value for 'numResults': ${numResults}`); + } + + // setting this too low messes with prefilter recall + numResults = Math.max(numResults, 64); + let results: AssetEntity[] = []; await this.assetRepository.manager.transaction(async (manager) => { - await manager.query(`SET LOCAL vectors.enable_prefilter = on`); - - let query = manager + const query = manager .createQueryBuilder(AssetEntity, 'a') .innerJoin('a.smartSearch', 's') .leftJoinAndSelect('a.exifInfo', 'e') .where('a.ownerId IN (:...userIds )') - .orderBy('s.embedding <=> :embedding') .setParameters({ userIds, embedding: asVector(embedding) }); @@ -61,15 +73,9 @@ export class SmartInfoRepository implements ISmartInfoRepository { query.andWhere('a.isArchived = false'); } query.andWhere('a.isVisible = true').andWhere('a.fileCreatedAt < NOW()'); + query.limit(numResults); - if (numResults) { - if (!isValidInteger(numResults, { min: 1 })) { - throw new Error(`Invalid value for 'numResults': ${numResults}`); - } - query = query.limit(numResults); - await manager.query(`SET LOCAL vectors.k = '${numResults}'`); - } - + await manager.query(this.getRuntimeConfig(numResults)); results = await query.getMany(); }); @@ -93,36 +99,34 @@ export class SmartInfoRepository implements ISmartInfoRepository { maxDistance, hasPerson, }: FaceEmbeddingSearch): Promise { + if (!isValidInteger(numResults, { min: 1 })) { + throw new Error(`Invalid value for 'numResults': ${numResults}`); + } + + // setting this too low messes with prefilter recall + numResults = Math.max(numResults, 64); + let results: Array = []; await this.assetRepository.manager.transaction(async (manager) => { - await manager.query(`SET LOCAL vectors.enable_prefilter = on`); - let cte = manager + const cte = manager .createQueryBuilder(AssetFaceEntity, 'faces') - .select('1 + (faces.embedding <=> :embedding)', 'distance') + .select('faces.embedding <=> :embedding', 'distance') .innerJoin('faces.asset', 'asset') .where('asset.ownerId IN (:...userIds )') - .orderBy('1 + (faces.embedding <=> :embedding)') + .orderBy('faces.embedding <=> :embedding') .setParameters({ userIds, embedding: asVector(embedding) }); - if (numResults) { - if (!isValidInteger(numResults, { min: 1 })) { - throw new Error(`Invalid value for 'numResults': ${numResults}`); - } - cte = cte.limit(numResults); - if (numResults > 64) { - // setting k too low messes with prefilter recall - await manager.query(`SET LOCAL vectors.k = '${numResults}'`); - } - } + cte.limit(numResults); if (hasPerson) { - cte = cte.andWhere('faces."personId" IS NOT NULL'); + cte.andWhere('faces."personId" IS NOT NULL'); } for (const col of this.faceColumns) { cte.addSelect(`faces.${col}`, col); } + await manager.query(this.getRuntimeConfig(numResults)); results = await manager .createQueryBuilder() .select('res.*') @@ -167,6 +171,9 @@ export class SmartInfoRepository implements ISmartInfoRepository { this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`); await this.smartSearchRepository.manager.transaction(async (manager) => { + if (vectorExt === DatabaseExtension.VECTORS) { + await manager.query(`SET vectors.pgvector_compatibility=on`); + } await manager.query(`DROP TABLE smart_search`); await manager.query(` @@ -175,12 +182,9 @@ export class SmartInfoRepository implements ISmartInfoRepository { embedding vector(${dimSize}) NOT NULL )`); await manager.query(` - CREATE INDEX clip_index ON smart_search - USING vectors (embedding cosine_ops) WITH (options = $$ - [indexing.hnsw] - m = 16 - ef_construction = 300 - $$)`); + CREATE INDEX IF NOT EXISTS clip_index ON smart_search + USING hnsw (embedding vector_cosine_ops) + WITH (ef_construction = 300, m = 16)`); }); this.logger.log(`Successfully updated database CLIP dimension size from ${currentDimSize} to ${dimSize}.`); @@ -202,4 +206,17 @@ export class SmartInfoRepository implements ISmartInfoRepository { } return dimSize; } + + private getRuntimeConfig(numResults?: number): string { + if (vectorExt === DatabaseExtension.VECTOR) { + return 'SET LOCAL hnsw.ef_search = 1000;'; // mitigate post-filter recall + } + + let runtimeConfig = 'SET LOCAL vectors.enable_prefilter=on; SET LOCAL vectors.search_mode=vbase;'; + if (numResults) { + runtimeConfig += ` SET LOCAL vectors.hnsw_ef_search = ${numResults};`; + } + + return runtimeConfig; + } } diff --git a/server/src/infra/sql/smart.info.repository.sql b/server/src/infra/sql/smart.info.repository.sql index afb120bade..3151aede73 100644 --- a/server/src/infra/sql/smart.info.repository.sql +++ b/server/src/infra/sql/smart.info.repository.sql @@ -3,9 +3,13 @@ -- SmartInfoRepository.searchCLIP START TRANSACTION SET - LOCAL vectors.enable_prefilter = on + LOCAL vectors.enable_prefilter = on; + SET - LOCAL vectors.k = '100' + LOCAL vectors.search_mode = vbase; + +SET + LOCAL vectors.hnsw_ef_search = 100; SELECT "a"."id" AS "a_id", "a"."deviceAssetId" AS "a_deviceAssetId", @@ -85,9 +89,13 @@ COMMIT -- SmartInfoRepository.searchFaces START TRANSACTION SET - LOCAL vectors.enable_prefilter = on + LOCAL vectors.enable_prefilter = on; + SET - LOCAL vectors.k = '100' + LOCAL vectors.search_mode = vbase; + +SET + LOCAL vectors.hnsw_ef_search = 100; WITH "cte" AS ( SELECT @@ -100,7 +108,7 @@ WITH "faces"."boundingBoxY1" AS "boundingBoxY1", "faces"."boundingBoxX2" AS "boundingBoxX2", "faces"."boundingBoxY2" AS "boundingBoxY2", - 1 + ("faces"."embedding" <= > $1) AS "distance" + "faces"."embedding" <= > $1 AS "distance" FROM "asset_faces" "faces" INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId" @@ -108,7 +116,7 @@ WITH WHERE "asset"."ownerId" IN ($2) ORDER BY - 1 + ("faces"."embedding" <= > $1) ASC + "faces"."embedding" <= > $1 ASC LIMIT 100 ) diff --git a/server/test/repositories/database.repository.mock.ts b/server/test/repositories/database.repository.mock.ts index f34e6b06b5..f5a4d39a67 100644 --- a/server/test/repositories/database.repository.mock.ts +++ b/server/test/repositories/database.repository.mock.ts @@ -3,8 +3,14 @@ import { IDatabaseRepository, Version } from '@app/domain'; export const newDatabaseRepositoryMock = (): jest.Mocked => { return { getExtensionVersion: jest.fn(), + getAvailableExtensionVersion: jest.fn(), + getPreferredVectorExtension: jest.fn(), getPostgresVersion: jest.fn().mockResolvedValue(new Version(14, 0, 0)), createExtension: jest.fn().mockImplementation(() => Promise.resolve()), + updateExtension: jest.fn(), + updateVectorExtension: jest.fn(), + reindex: jest.fn(), + shouldReindex: jest.fn(), runMigrations: jest.fn(), withLock: jest.fn().mockImplementation((_, function_: () => Promise) => function_()), isBusy: jest.fn(),