1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-17 01:06:46 +01:00

feat(server)!: pgvecto.rs 0.2 and pgvector compatibility (#6785)

* basic changes

update version check

set ef_search for clip

* pgvector compatibility

Revert "pgvector compatibility"

This reverts commit 2b66a52aa4097dd27da58138c5288fd87cb9b24a.

pgvector compatibility: minimal edition

pgvector startup check

* update extension at startup

* wording

shortened vector extension variable name

* nightly docker

* fixed version checks

* update tests

add tests for updating extension

remove unnecessary check

* simplify `getRuntimeConfig`

* wording

* reindex on minor version update

* 0.2 upgrade testing

update prod compose

* acquire lock for init

* wip vector down on shutdown

* use upgrade helper

* update image tag

* refine restart check

check error message

* test reindex

testing

upstream fix

formatting

fixed reindexing

* use enum in signature

* fix tests

remove unused code

* add reindexing tests

* update to official 0.2

remove alpha from version name

* add warning test if restart required

* update test image to 0.2.0

* linting and test cleanup

* formatting

* update sql

* wording

* handle setting search path for new and existing databases

* handle new db in reindex check

* fix post-update reindexing

* get dim size

* formatting

* use vbase

* handle different db name

* update sql

* linting

* fix suggested env
This commit is contained in:
Mert 2024-02-06 21:46:38 -05:00 committed by GitHub
parent b2775c445a
commit 56b0643890
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 650 additions and 246 deletions

View file

@ -275,7 +275,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
services: services:
postgres: postgres:
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee image: tensorchord/pgvecto-rs:pg14-v0.2.0
env: env:
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres POSTGRES_USER: postgres

View file

@ -25,7 +25,7 @@ export default async () => {
if (process.env.DB_HOSTNAME === undefined) { if (process.env.DB_HOSTNAME === undefined) {
// DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container. // DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container.
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11') const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
.withExposedPorts(5432) .withExposedPorts(5432)
.withDatabase('immich') .withDatabase('immich')
.withUsername('postgres') .withUsername('postgres')

View file

@ -103,7 +103,7 @@ services:
database: database:
container_name: immich_postgres container_name: immich_postgres
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee image: tensorchord/pgvecto-rs:pg14-v0.2.0
env_file: env_file:
- .env - .env
environment: environment:

View file

@ -61,7 +61,7 @@ services:
database: database:
container_name: immich_postgres container_name: immich_postgres
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee image: tensorchord/pgvecto-rs:pg14-v0.2.0
env_file: env_file:
- .env - .env
environment: environment:
@ -70,7 +70,8 @@ services:
POSTGRES_DB: ${DB_DATABASE_NAME} POSTGRES_DB: ${DB_DATABASE_NAME}
volumes: volumes:
- ${UPLOAD_LOCATION}/postgres:/var/lib/postgresql/data - ${UPLOAD_LOCATION}/postgres:/var/lib/postgresql/data
restart: always ports:
- 5432:5432
volumes: volumes:
model-cache: model-cache:

View file

@ -65,7 +65,7 @@ services:
database: database:
container_name: immich_postgres container_name: immich_postgres
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee image: tensorchord/pgvecto-rs:pg14-v0.2.0
env_file: env_file:
- .env - .env
environment: environment:

View file

@ -1,7 +1,7 @@
import { PostgreSqlContainer } from '@testcontainers/postgresql'; import { PostgreSqlContainer } from '@testcontainers/postgresql';
export default async () => { export default async () => {
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11') const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
.withDatabase('immich') .withDatabase('immich')
.withUsername('postgres') .withUsername('postgres')
.withPassword('postgres') .withPassword('postgres')

View file

@ -21,7 +21,7 @@ services:
- database - database
database: database:
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee image: tensorchord/pgvecto-rs:pg14-v0.2.0
command: -c fsync=off -c shared_preload_libraries=vectors.so command: -c fsync=off -c shared_preload_libraries=vectors.so
environment: environment:
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres

View file

@ -25,7 +25,7 @@ export default async () => {
if (process.env.DB_HOSTNAME === undefined) { if (process.env.DB_HOSTNAME === undefined) {
// DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container. // DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container.
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11') const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
.withExposedPorts(5432) .withExposedPorts(5432)
.withDatabase('immich') .withDatabase('immich')
.withUsername('postgres') .withUsername('postgres')

View file

@ -1,41 +1,65 @@
import { DatabaseExtension, DatabaseService, IDatabaseRepository, Version } from '@app/domain'; import {
DatabaseExtension,
DatabaseService,
IDatabaseRepository,
VectorIndex,
Version,
VersionType,
} from '@app/domain';
import { ImmichLogger } from '@app/infra/logger'; import { ImmichLogger } from '@app/infra/logger';
import { newDatabaseRepositoryMock } from '@test'; import { newDatabaseRepositoryMock } from '@test';
describe(DatabaseService.name, () => { describe(DatabaseService.name, () => {
let sut: DatabaseService; let sut: DatabaseService;
let databaseMock: jest.Mocked<IDatabaseRepository>; let databaseMock: jest.Mocked<IDatabaseRepository>;
let fatalLog: jest.SpyInstance;
beforeEach(async () => { beforeEach(async () => {
databaseMock = newDatabaseRepositoryMock(); databaseMock = newDatabaseRepositoryMock();
fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal');
sut = new DatabaseService(databaseMock); sut = new DatabaseService(databaseMock);
sut.minVectorsVersion = new Version(0, 1, 1);
sut.maxVectorsVersion = new Version(0, 1, 11);
});
afterEach(() => {
fatalLog.mockRestore();
}); });
it('should work', () => { it('should work', () => {
expect(sut).toBeDefined(); expect(sut).toBeDefined();
}); });
describe('init', () => { describe.each([
it('should resolve successfully if minimum supported PostgreSQL and vectors version are installed', async () => { [{ vectorExt: DatabaseExtension.VECTORS, extName: 'pgvecto.rs', minVersion: new Version(0, 1, 1) }],
[{ vectorExt: DatabaseExtension.VECTOR, extName: 'pgvector', minVersion: new Version(0, 5, 0) }],
] as const)('init', ({ vectorExt, extName, minVersion }) => {
let fatalLog: jest.SpyInstance;
let errorLog: jest.SpyInstance;
let warnLog: jest.SpyInstance;
beforeEach(async () => {
fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal');
errorLog = jest.spyOn(ImmichLogger.prototype, 'error');
warnLog = jest.spyOn(ImmichLogger.prototype, 'warn');
databaseMock.getPreferredVectorExtension.mockReturnValue(vectorExt);
databaseMock.getExtensionVersion.mockResolvedValue(minVersion);
sut = new DatabaseService(databaseMock);
sut.minVectorVersion = minVersion;
sut.minVectorsVersion = minVersion;
sut.vectorVersionPin = VersionType.MINOR;
sut.vectorsVersionPin = VersionType.MINOR;
});
afterEach(() => {
fatalLog.mockRestore();
warnLog.mockRestore();
});
it(`should resolve successfully if minimum supported PostgreSQL and ${extName} version are installed`, async () => {
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(14, 0, 0)); databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(14, 0, 0));
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1));
await expect(sut.init()).resolves.toBeUndefined(); await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(2); expect(databaseMock.getPostgresVersion).toHaveBeenCalled();
expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); expect(databaseMock.createExtension).toHaveBeenCalledWith(vectorExt);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); expect(databaseMock.getExtensionVersion).toHaveBeenCalled();
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled(); expect(fatalLog).not.toHaveBeenCalled();
}); });
@ -43,112 +67,162 @@ describe(DatabaseService.name, () => {
it('should throw an error if PostgreSQL version is below minimum supported version', async () => { it('should throw an error if PostgreSQL version is below minimum supported version', async () => {
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(13, 0, 0)); databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(13, 0, 0));
await expect(sut.init()).rejects.toThrow(/PostgreSQL version is 13/s); await expect(sut.init()).rejects.toThrow('PostgreSQL version is 13');
expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(1); expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(1);
}); });
it('should resolve successfully if minimum supported vectors version is installed', async () => { it(`should resolve successfully if minimum supported ${extName} version is installed`, async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1));
await expect(sut.init()).resolves.toBeUndefined(); await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); expect(databaseMock.createExtension).toHaveBeenCalledWith(vectorExt);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled(); expect(fatalLog).not.toHaveBeenCalled();
}); });
it('should resolve successfully if maximum supported vectors version is installed', async () => { it(`should throw an error if ${extName} version is not installed even after createVectorExtension`, async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 11)); databaseMock.getExtensionVersion.mockResolvedValue(null);
await expect(sut.init()).resolves.toBeUndefined(); await expect(sut.init()).rejects.toThrow(`Unexpected: ${extName} extension is not installed.`);
expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it('should throw an error if vectors version is not installed even after createVectors', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(null);
await expect(sut.init()).rejects.toThrow('Unexpected: The pgvecto.rs extension is not installed.');
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled(); expect(databaseMock.runMigrations).not.toHaveBeenCalled();
}); });
it('should throw an error if vectors version is below minimum supported version', async () => { it(`should throw an error if ${extName} version is below minimum supported version`, async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 1)); databaseMock.getExtensionVersion.mockResolvedValue(
new Version(minVersion.major, minVersion.minor - 1, minVersion.patch),
await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0.1.11')/s);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
});
it('should throw an error if vectors version is above maximum supported version', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 12));
await expect(sut.init()).rejects.toThrow(
/('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s,
); );
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); await expect(sut.init()).rejects.toThrow(extName);
expect(databaseMock.runMigrations).not.toHaveBeenCalled(); expect(databaseMock.runMigrations).not.toHaveBeenCalled();
}); });
it('should throw an error if vectors version is a nightly', async () => { it.each([
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 0)); { type: VersionType.EQUAL, max: 'no', actual: 'patch' },
{ type: VersionType.PATCH, max: 'patch', actual: 'minor' },
{ type: VersionType.MINOR, max: 'minor', actual: 'major' },
] as const)(
`should throw an error if $max upgrade from min version is allowed and ${extName} version is $actual`,
async ({ type, actual }) => {
const version = new Version(minVersion.major, minVersion.minor, minVersion.patch);
version[actual] = minVersion[actual] + 1;
databaseMock.getExtensionVersion.mockResolvedValue(version);
if (vectorExt === DatabaseExtension.VECTOR) {
sut.minVectorVersion = minVersion;
sut.vectorVersionPin = type;
} else {
sut.minVectorsVersion = minVersion;
sut.vectorsVersionPin = type;
}
await expect(sut.init()).rejects.toThrow( await expect(sut.init()).rejects.toThrow(extName);
/(nightly).*('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s,
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
},
); );
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); it(`should throw an error if ${extName} version is a nightly`, async () => {
databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 0));
await expect(sut.init()).rejects.toThrow(extName);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled(); expect(databaseMock.runMigrations).not.toHaveBeenCalled();
}); });
it('should throw error if vectors extension could not be created', async () => { it(`should throw error if ${extName} extension could not be created`, async () => {
databaseMock.createExtension.mockRejectedValueOnce(new Error('Failed to create extension')); databaseMock.createExtension.mockRejectedValue(new Error('Failed to create extension'));
await expect(sut.init()).rejects.toThrow('Failed to create extension'); await expect(sut.init()).rejects.toThrow('Failed to create extension');
expect(fatalLog).toHaveBeenCalledTimes(1); expect(fatalLog).toHaveBeenCalledTimes(1);
expect(fatalLog.mock.calls[0][0]).toMatch(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11').*(v1\.91\.0)/s);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled(); expect(databaseMock.runMigrations).not.toHaveBeenCalled();
}); });
it.each([{ major: 14 }, { major: 15 }, { major: 16 }])( it(`should update ${extName} if a newer version is available`, async () => {
`should suggest image with postgres $major if database is $major`, const version = new Version(minVersion.major, minVersion.minor + 1, minVersion.patch);
async ({ major }) => { databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1));
databaseMock.getPostgresVersion.mockResolvedValue(new Version(major, 0, 0));
await expect(sut.init()).rejects.toThrow(new RegExp(`tensorchord\/pgvecto-rs:pg${major}-v0\\.1\\.11`, 's')); await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version);
expect(databaseMock.updateVectorExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it(`should not update ${extName} if a newer version is higher than the maximum`, async () => {
const version = new Version(minVersion.major + 1, minVersion.minor, minVersion.patch);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.updateVectorExtension).not.toHaveBeenCalled();
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it(`should warn if attempted to update ${extName} and failed`, async () => {
const version = new Version(minVersion.major, minVersion.minor, minVersion.patch + 1);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
databaseMock.updateVectorExtension.mockRejectedValue(new Error('Failed to update extension'));
await expect(sut.init()).resolves.toBeUndefined();
expect(warnLog).toHaveBeenCalledTimes(1);
expect(warnLog.mock.calls[0][0]).toContain(extName);
expect(errorLog).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
});
it(`should warn if ${extName} update requires restart`, async () => {
const version = new Version(minVersion.major, minVersion.minor, minVersion.patch + 1);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
databaseMock.updateVectorExtension.mockResolvedValue({ restartRequired: true });
await expect(sut.init()).resolves.toBeUndefined();
expect(warnLog).toHaveBeenCalledTimes(1);
expect(warnLog.mock.calls[0][0]).toContain(extName);
expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it.each([{ index: VectorIndex.CLIP }, { index: VectorIndex.FACE }])(
`should reindex $index if necessary`,
async ({ index }) => {
databaseMock.shouldReindex.mockImplementation((indexArg) => Promise.resolve(indexArg === index));
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.shouldReindex).toHaveBeenCalledWith(index);
expect(databaseMock.shouldReindex).toHaveBeenCalledTimes(2);
expect(databaseMock.reindex).toHaveBeenCalledWith(index);
expect(databaseMock.reindex).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
}, },
); );
it('should not suggest image if postgres version is not in 14, 15 or 16', async () => { it.each([{ index: VectorIndex.CLIP }, { index: VectorIndex.FACE }])(
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(17, 0, 0)); `should not reindex $index if not necessary`,
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(17, 0, 0)); async () => {
databaseMock.shouldReindex.mockResolvedValue(false);
await expect(sut.init()).rejects.toThrow(/^(?:(?!tensorchord\/pgvecto-rs).)*$/s); await expect(sut.init()).resolves.toBeUndefined();
});
it('should reject and suggest the maximum supported version when unsupported pgvecto.rs version is in use', async () => { expect(databaseMock.shouldReindex).toHaveBeenCalledTimes(2);
databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1)); expect(databaseMock.reindex).not.toHaveBeenCalled();
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s); expect(fatalLog).not.toHaveBeenCalled();
},
sut.maxVectorsVersion = new Version(0, 1, 12); );
await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.12')/s);
});
}); });
}); });

View file

@ -1,72 +1,54 @@
import { ImmichLogger } from '@app/infra/logger'; import { ImmichLogger } from '@app/infra/logger';
import { Inject, Injectable } from '@nestjs/common'; import { Inject, Injectable } from '@nestjs/common';
import { QueryFailedError } from 'typeorm'; import { QueryFailedError } from 'typeorm';
import { Version } from '../domain.constant'; import { Version, VersionType } from '../domain.constant';
import { DatabaseExtension, IDatabaseRepository } from '../repositories'; import {
DatabaseExtension,
DatabaseLock,
IDatabaseRepository,
VectorExtension,
VectorIndex,
extName,
} from '../repositories';
@Injectable() @Injectable()
export class DatabaseService { export class DatabaseService {
private logger = new ImmichLogger(DatabaseService.name); private logger = new ImmichLogger(DatabaseService.name);
private vectorExt: VectorExtension;
minPostgresVersion = 14; minPostgresVersion = 14;
minVectorsVersion = new Version(0, 1, 1); minVectorsVersion = new Version(0, 2, 0);
maxVectorsVersion = new Version(0, 1, 11); vectorsVersionPin = VersionType.MINOR;
minVectorVersion = new Version(0, 5, 0);
vectorVersionPin = VersionType.MAJOR;
constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) {} constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) {
this.vectorExt = this.databaseRepository.getPreferredVectorExtension();
}
async init() { async init() {
await this.assertPostgresql(); await this.assertPostgresql();
await this.createVectors(); await this.databaseRepository.withLock(DatabaseLock.Migrations, async () => {
await this.assertVectors(); await this.createVectorExtension();
await this.databaseRepository.runMigrations(); await this.updateVectorExtension();
await this.assertVectorExtension();
try {
if (await this.databaseRepository.shouldReindex(VectorIndex.CLIP)) {
await this.databaseRepository.reindex(VectorIndex.CLIP);
} }
private async assertVectors() { if (await this.databaseRepository.shouldReindex(VectorIndex.FACE)) {
const version = await this.databaseRepository.getExtensionVersion(DatabaseExtension.VECTORS); await this.databaseRepository.reindex(VectorIndex.FACE);
if (version == null) {
throw new Error('Unexpected: The pgvecto.rs extension is not installed.');
} }
} catch (error) {
const image = await this.getVectorsImage(); this.logger.warn(
const suggestion = image ? `, such as with the docker image '${image}'` : ''; 'Could not run vector reindexing checks. If the extension was updated, please restart the Postgres instance.',
if (version.isEqual(new Version(0, 0, 0))) {
throw new Error(
`The pgvecto.rs extension version is ${version}, which means it is a nightly release.` +
`Please run 'DROP EXTENSION IF EXISTS vectors' and switch to a release version${suggestion}.`,
); );
}
if (version.isNewerThan(this.maxVectorsVersion)) {
throw new Error(`
The pgvecto.rs extension version is ${version} instead of ${this.maxVectorsVersion}.
Please run 'DROP EXTENSION IF EXISTS vectors' and switch to ${this.maxVectorsVersion}${suggestion}.`);
}
if (version.isOlderThan(this.minVectorsVersion)) {
throw new Error(`
The pgvecto.rs extension version is ${version}, which is older than the minimum supported version ${this.minVectorsVersion}.
Please upgrade to this version or later${suggestion}.`);
}
}
private async createVectors() {
await this.databaseRepository.createExtension(DatabaseExtension.VECTORS).catch(async (error: QueryFailedError) => {
const image = await this.getVectorsImage();
this.logger.fatal(`
Failed to create pgvecto.rs extension.
If you have not updated your Postgres instance to a docker image that supports pgvecto.rs (such as '${image}'), please do so.
See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0'
`);
throw error; throw error;
});
} }
private async getVectorsImage() { await this.databaseRepository.runMigrations();
const { major } = await this.databaseRepository.getPostgresVersion(); });
if (![14, 15, 16].includes(major)) {
return null;
}
return `tensorchord/pgvecto-rs:pg${major}-v${this.maxVectorsVersion}`;
} }
private async assertPostgresql() { private async assertPostgresql() {
@ -77,4 +59,99 @@ export class DatabaseService {
Please upgrade to this version or later.`); Please upgrade to this version or later.`);
} }
} }
private async createVectorExtension() {
await this.databaseRepository.createExtension(this.vectorExt).catch(async (error: QueryFailedError) => {
const otherExt =
this.vectorExt === DatabaseExtension.VECTORS ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS;
this.logger.fatal(`
Failed to activate ${extName[this.vectorExt]} extension.
Please ensure the Postgres instance has ${extName[this.vectorExt]} installed.
If the Postgres instance already has ${extName[this.vectorExt]} installed, Immich may not have the necessary permissions to activate it.
In this case, please run 'CREATE EXTENSION IF NOT EXISTS ${this.vectorExt}' manually as a superuser.
See https://immich.app/docs/guides/database-queries for how to query the database.
Alternatively, if your Postgres instance has ${extName[otherExt]}, you may use this instead by setting the environment variable 'VECTOR_EXTENSION=${otherExt}'.
Note that switching between the two extensions after a successful startup is not supported.
The exception is if your version of Immich prior to upgrading was 1.90.2 or earlier.
In this case, you may set either extension now, but you will not be able to switch to the other extension following a successful startup.
`);
throw error;
});
}
private async updateVectorExtension() {
const [version, availableVersion] = await Promise.all([
this.databaseRepository.getExtensionVersion(this.vectorExt),
this.databaseRepository.getAvailableExtensionVersion(this.vectorExt),
]);
if (version == null) {
throw new Error(`Unexpected: ${extName[this.vectorExt]} extension is not installed.`);
}
if (availableVersion == null) {
return;
}
const maxVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.vectorVersionPin : this.vectorsVersionPin;
const isNewer = availableVersion.isNewerThan(version);
if (isNewer == null || isNewer > maxVersion) {
return;
}
try {
this.logger.log(`Updating ${extName[this.vectorExt]} extension to ${availableVersion}`);
const { restartRequired } = await this.databaseRepository.updateVectorExtension(this.vectorExt, availableVersion);
if (restartRequired) {
this.logger.warn(`
The ${extName[this.vectorExt]} extension has been updated to ${availableVersion}.
Please restart the Postgres instance to complete the update.`);
}
} catch (error) {
this.logger.warn(`
The ${extName[this.vectorExt]} extension version is ${version}, but ${availableVersion} is available.
Immich attempted to update the extension, but failed to do so.
This may be because Immich does not have the necessary permissions to update the extension.
Please run 'ALTER EXTENSION ${this.vectorExt} UPDATE' manually as a superuser.
See https://immich.app/docs/guides/database-queries for how to query the database.`);
this.logger.error(error);
}
}
private async assertVectorExtension() {
const version = await this.databaseRepository.getExtensionVersion(this.vectorExt);
if (version == null) {
throw new Error(`Unexpected: The ${extName[this.vectorExt]} extension is not installed.`);
}
if (version.isEqual(new Version(0, 0, 0))) {
throw new Error(`
The ${extName[this.vectorExt]} extension version is ${version}, which means it is a nightly release.
Please run 'DROP EXTENSION IF EXISTS ${this.vectorExt}' and switch to a release version.
See https://immich.app/docs/guides/database-queries for how to query the database.`);
}
const minVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.minVectorVersion : this.minVectorsVersion;
const maxVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.vectorVersionPin : this.vectorsVersionPin;
if (version.isOlderThan(minVersion) || version.isNewerThan(minVersion) > maxVersion) {
const allowedReleaseType = maxVersion === VersionType.MAJOR ? '' : ` ${VersionType[maxVersion].toLowerCase()}`;
const releases =
maxVersion === VersionType.EQUAL
? minVersion.toString()
: `${minVersion} and later${allowedReleaseType} releases`;
throw new Error(`
The ${extName[this.vectorExt]} extension version is ${version}, but Immich only supports ${releases}.
If the Postgres instance already has a compatible version installed, Immich may not have the necessary permissions to activate it.
In this case, please run 'ALTER EXTENSION UPDATE ${this.vectorExt}' manually as a superuser.
See https://immich.app/docs/guides/database-queries for how to query the database.
Otherwise, please update the version of ${extName[this.vectorExt]} in the Postgres instance to a compatible version.`);
}
}
} }

View file

@ -24,5 +24,6 @@ export const immichAppConfig: ConfigModuleOptions = {
MACHINE_LEARNING_PORT: Joi.number().optional(), MACHINE_LEARNING_PORT: Joi.number().optional(),
MICROSERVICES_PORT: Joi.number().optional(), MICROSERVICES_PORT: Joi.number().optional(),
SERVER_PORT: Joi.number().optional(), SERVER_PORT: Joi.number().optional(),
VECTOR_EXTENSION: Joi.string().optional().valid('pgvector', 'pgvecto.rs').default('pgvecto.rs'),
}), }),
}; };

View file

@ -1,4 +1,4 @@
import { Version, mimeTypes } from './domain.constant'; import { Version, VersionType, mimeTypes } from './domain.constant';
describe('mimeTypes', () => { describe('mimeTypes', () => {
for (const { mimetype, extension } of [ for (const { mimetype, extension } of [
@ -196,45 +196,37 @@ describe('mimeTypes', () => {
}); });
}); });
describe('ServerVersion', () => { describe('Version', () => {
const tests = [ const tests = [
{ this: new Version(0, 0, 1), other: new Version(0, 0, 0), expected: 1 }, { this: new Version(0, 0, 1), other: new Version(0, 0, 0), compare: 1, type: VersionType.PATCH },
{ this: new Version(0, 1, 0), other: new Version(0, 0, 0), expected: 1 }, { this: new Version(0, 1, 0), other: new Version(0, 0, 0), compare: 1, type: VersionType.MINOR },
{ this: new Version(1, 0, 0), other: new Version(0, 0, 0), expected: 1 }, { this: new Version(1, 0, 0), other: new Version(0, 0, 0), compare: 1, type: VersionType.MAJOR },
{ this: new Version(0, 0, 0), other: new Version(0, 0, 1), expected: -1 }, { this: new Version(0, 0, 0), other: new Version(0, 0, 1), compare: -1, type: VersionType.PATCH },
{ this: new Version(0, 0, 0), other: new Version(0, 1, 0), expected: -1 }, { this: new Version(0, 0, 0), other: new Version(0, 1, 0), compare: -1, type: VersionType.MINOR },
{ this: new Version(0, 0, 0), other: new Version(1, 0, 0), expected: -1 }, { this: new Version(0, 0, 0), other: new Version(1, 0, 0), compare: -1, type: VersionType.MAJOR },
{ this: new Version(0, 0, 0), other: new Version(0, 0, 0), expected: 0 }, { this: new Version(0, 0, 0), other: new Version(0, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(0, 0, 1), other: new Version(0, 0, 1), expected: 0 }, { this: new Version(0, 0, 1), other: new Version(0, 0, 1), compare: 0, type: VersionType.EQUAL },
{ this: new Version(0, 1, 0), other: new Version(0, 1, 0), expected: 0 }, { this: new Version(0, 1, 0), other: new Version(0, 1, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1, 0, 0), other: new Version(1, 0, 0), expected: 0 }, { this: new Version(1, 0, 0), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1, 0), other: new Version(1, 0, 0), expected: 0 }, { this: new Version(1, 0), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1, 0), other: new Version(1, 0, 1), expected: -1 }, { this: new Version(1, 0), other: new Version(1, 0, 1), compare: -1, type: VersionType.PATCH },
{ this: new Version(1, 1), other: new Version(1, 0, 1), expected: 1 }, { this: new Version(1, 1), other: new Version(1, 0, 1), compare: 1, type: VersionType.MINOR },
{ this: new Version(1), other: new Version(1, 0, 0), expected: 0 }, { this: new Version(1), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1), other: new Version(1, 0, 1), expected: -1 }, { this: new Version(1), other: new Version(1, 0, 1), compare: -1, type: VersionType.PATCH },
]; ];
describe('compare', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) {
it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.compare(otherVersion)).toEqual(expected);
});
}
});
describe('isOlderThan', () => { describe('isOlderThan', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) { for (const { this: thisVersion, other: otherVersion, compare, type } of tests) {
const bool = expected < 0; const expected = compare < 0 ? type : VersionType.EQUAL;
it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { it(`should return '${expected}' when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isOlderThan(otherVersion)).toEqual(bool); expect(thisVersion.isOlderThan(otherVersion)).toEqual(expected);
}); });
} }
}); });
describe('isEqual', () => { describe('isEqual', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) { for (const { this: thisVersion, other: otherVersion, compare } of tests) {
const bool = expected === 0; const bool = compare === 0;
it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isEqual(otherVersion)).toEqual(bool); expect(thisVersion.isEqual(otherVersion)).toEqual(bool);
}); });
@ -242,10 +234,10 @@ describe('ServerVersion', () => {
}); });
describe('isNewerThan', () => { describe('isNewerThan', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) { for (const { this: thisVersion, other: otherVersion, compare, type } of tests) {
const bool = expected > 0; const expected = compare > 0 ? type : VersionType.EQUAL;
it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isNewerThan(otherVersion)).toEqual(bool); expect(thisVersion.isNewerThan(otherVersion)).toEqual(expected);
}); });
} }
}); });

View file

@ -12,11 +12,20 @@ export interface IVersion {
patch: number; patch: number;
} }
export enum VersionType {
EQUAL = 0,
PATCH = 1,
MINOR = 2,
MAJOR = 3,
}
export class Version implements IVersion { export class Version implements IVersion {
public readonly types = ['major', 'minor', 'patch'] as const;
constructor( constructor(
public readonly major: number, public major: number,
public readonly minor: number = 0, public minor: number = 0,
public readonly patch: number = 0, public patch: number = 0,
) {} ) {}
toString() { toString() {
@ -39,27 +48,30 @@ export class Version implements IVersion {
} }
} }
compare(version: Version): number { private compare(version: Version): [number, VersionType] {
for (const key of ['major', 'minor', 'patch'] as const) { for (const [i, key] of this.types.entries()) {
const diff = this[key] - version[key]; const diff = this[key] - version[key];
if (diff !== 0) { if (diff !== 0) {
return diff > 0 ? 1 : -1; return [diff > 0 ? 1 : -1, (VersionType.MAJOR - i) as VersionType];
} }
} }
return 0; return [0, VersionType.EQUAL];
} }
isOlderThan(version: Version): boolean { isOlderThan(version: Version): VersionType {
return this.compare(version) < 0; const [bool, type] = this.compare(version);
return bool < 0 ? type : VersionType.EQUAL;
} }
isEqual(version: Version): boolean { isEqual(version: Version): boolean {
return this.compare(version) === 0; const [bool] = this.compare(version);
return bool === 0;
} }
isNewerThan(version: Version): boolean { isNewerThan(version: Version): VersionType {
return this.compare(version) > 0; const [bool, type] = this.compare(version);
return bool > 0 ? type : VersionType.EQUAL;
} }
} }

View file

@ -3,21 +3,47 @@ import { Version } from '../domain.constant';
export enum DatabaseExtension { export enum DatabaseExtension {
CUBE = 'cube', CUBE = 'cube',
EARTH_DISTANCE = 'earthdistance', EARTH_DISTANCE = 'earthdistance',
VECTOR = 'vector',
VECTORS = 'vectors', VECTORS = 'vectors',
} }
export type VectorExtension = DatabaseExtension.VECTOR | DatabaseExtension.VECTORS;
export enum VectorIndex {
CLIP = 'clip_index',
FACE = 'face_index',
}
export enum DatabaseLock { export enum DatabaseLock {
GeodataImport = 100, GeodataImport = 100,
Migrations = 200,
StorageTemplateMigration = 420, StorageTemplateMigration = 420,
CLIPDimSize = 512, CLIPDimSize = 512,
} }
export const extName: Record<DatabaseExtension, string> = {
cube: 'cube',
earthdistance: 'earthdistance',
vector: 'pgvector',
vectors: 'pgvecto.rs',
} as const;
export interface VectorUpdateResult {
restartRequired: boolean;
}
export const IDatabaseRepository = 'IDatabaseRepository'; export const IDatabaseRepository = 'IDatabaseRepository';
export interface IDatabaseRepository { export interface IDatabaseRepository {
getExtensionVersion(extensionName: string): Promise<Version | null>; getExtensionVersion(extensionName: string): Promise<Version | null>;
getAvailableExtensionVersion(extension: DatabaseExtension): Promise<Version | null>;
getPreferredVectorExtension(): VectorExtension;
getPostgresVersion(): Promise<Version>; getPostgresVersion(): Promise<Version>;
createExtension(extension: DatabaseExtension): Promise<void>; createExtension(extension: DatabaseExtension): Promise<void>;
updateExtension(extension: DatabaseExtension, version?: Version): Promise<void>;
updateVectorExtension(extension: VectorExtension, version?: Version): Promise<VectorUpdateResult>;
reindex(index: VectorIndex): Promise<void>;
shouldReindex(name: VectorIndex): Promise<boolean>;
runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void>; runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void>;
withLock<R>(lock: DatabaseLock, callback: () => Promise<R>): Promise<R>; withLock<R>(lock: DatabaseLock, callback: () => Promise<R>): Promise<R>;
isBusy(lock: DatabaseLock): boolean; isBusy(lock: DatabaseLock): boolean;

View file

@ -7,7 +7,7 @@ export type Embedding = number[];
export interface EmbeddingSearch { export interface EmbeddingSearch {
userIds: string[]; userIds: string[];
embedding: Embedding; embedding: Embedding;
numResults?: number; numResults: number;
withArchived?: boolean; withArchived?: boolean;
} }

View file

@ -1,3 +1,4 @@
import { DatabaseExtension } from '@app/domain/repositories/database.repository';
import { DataSource } from 'typeorm'; import { DataSource } from 'typeorm';
import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions.js'; import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions.js';
@ -27,3 +28,6 @@ export const databaseConfig: PostgresConnectionOptions = {
// this export is used by TypeORM commands in package.json#scripts // this export is used by TypeORM commands in package.json#scripts
export const dataSource = new DataSource(databaseConfig); export const dataSource = new DataSource(databaseConfig);
export const vectorExt =
process.env.VECTOR_EXTENSION === 'pgvector' ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS;

View file

@ -5,7 +5,7 @@ import { LogLevel } from './entities';
const LOG_LEVELS = [LogLevel.VERBOSE, LogLevel.DEBUG, LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; const LOG_LEVELS = [LogLevel.VERBOSE, LogLevel.DEBUG, LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL];
export class ImmichLogger extends ConsoleLogger { export class ImmichLogger extends ConsoleLogger {
private static logLevels: LogLevel[] = [LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; private static logLevels: LogLevel[] = [LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL];
constructor(context: string) { constructor(context: string) {
super(context); super(context);

View file

@ -1,11 +1,13 @@
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
import { MigrationInterface, QueryRunner } from 'typeorm'; import { MigrationInterface, QueryRunner } from 'typeorm';
import { vectorExt } from '@app/infra/database.config';
export class UsePgVectors1700713871511 implements MigrationInterface { export class UsePgVectors1700713871511 implements MigrationInterface {
name = 'UsePgVectors1700713871511'; name = 'UsePgVectors1700713871511';
public async up(queryRunner: QueryRunner): Promise<void> { public async up(queryRunner: QueryRunner): Promise<void> {
await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS vectors`); await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS ${vectorExt}`);
const faceDimQuery = await queryRunner.query(` const faceDimQuery = await queryRunner.query(`
SELECT CARDINALITY(embedding::real[]) as dimsize SELECT CARDINALITY(embedding::real[]) as dimsize
FROM asset_faces FROM asset_faces

View file

@ -1,16 +1,20 @@
import { MigrationInterface, QueryRunner } from 'typeorm'; import { MigrationInterface, QueryRunner } from 'typeorm';
import { vectorExt } from '../database.config';
import { DatabaseExtension } from '@app/domain/repositories/database.repository';
export class AddCLIPEmbeddingIndex1700713994428 implements MigrationInterface { export class AddCLIPEmbeddingIndex1700713994428 implements MigrationInterface {
name = 'AddCLIPEmbeddingIndex1700713994428'; name = 'AddCLIPEmbeddingIndex1700713994428';
public async up(queryRunner: QueryRunner): Promise<void> { public async up(queryRunner: QueryRunner): Promise<void> {
if (vectorExt === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(` await queryRunner.query(`
CREATE INDEX IF NOT EXISTS clip_index ON smart_search CREATE INDEX IF NOT EXISTS clip_index ON smart_search
USING vectors (embedding cosine_ops) WITH (options = $$ USING hnsw (embedding vector_cosine_ops)
[indexing.hnsw] WITH (ef_construction = 300, m = 16)`);
m = 16
ef_construction = 300
$$);`);
} }
public async down(queryRunner: QueryRunner): Promise<void> { public async down(queryRunner: QueryRunner): Promise<void> {

View file

@ -1,16 +1,20 @@
import { MigrationInterface, QueryRunner } from 'typeorm'; import { MigrationInterface, QueryRunner } from 'typeorm';
import { vectorExt } from '../database.config';
import { DatabaseExtension } from '@app/domain/repositories/database.repository';
export class AddFaceEmbeddingIndex1700714033632 implements MigrationInterface { export class AddFaceEmbeddingIndex1700714033632 implements MigrationInterface {
name = 'AddFaceEmbeddingIndex1700714033632'; name = 'AddFaceEmbeddingIndex1700714033632';
public async up(queryRunner: QueryRunner): Promise<void> { public async up(queryRunner: QueryRunner): Promise<void> {
if (vectorExt === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(` await queryRunner.query(`
CREATE INDEX IF NOT EXISTS face_index ON asset_faces CREATE INDEX IF NOT EXISTS face_index ON asset_faces
USING vectors (embedding cosine_ops) WITH (options = $$ USING hnsw (embedding vector_cosine_ops)
[indexing.hnsw] WITH (ef_construction = 300, m = 16)`);
m = 16
ef_construction = 300
$$);`);
} }
public async down(queryRunner: QueryRunner): Promise<void> { public async down(queryRunner: QueryRunner): Promise<void> {

View file

@ -0,0 +1,14 @@
import { MigrationInterface, QueryRunner } from 'typeorm';
export class AddVectorsToSearchPath1707000751533 implements MigrationInterface {
public async up(queryRunner: QueryRunner): Promise<void> {
const res = await queryRunner.query(`SELECT current_database() as db`);
const databaseName = res[0]['db'];
await queryRunner.query(`ALTER DATABASE ${databaseName} SET search_path TO "$user", public, vectors`);
}
public async down(queryRunner: QueryRunner): Promise<void> {
const databaseName = await queryRunner.query(`SELECT current_database()`);
await queryRunner.query(`ALTER DATABASE ${databaseName} SET search_path TO "$user", public`);
}
}

View file

@ -1,21 +1,60 @@
import { DatabaseExtension, DatabaseLock, IDatabaseRepository, Version } from '@app/domain'; import {
DatabaseExtension,
DatabaseLock,
IDatabaseRepository,
VectorExtension,
VectorIndex,
VectorUpdateResult,
Version,
VersionType,
extName,
} from '@app/domain';
import { vectorExt } from '@app/infra/database.config';
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { InjectDataSource } from '@nestjs/typeorm'; import { InjectDataSource } from '@nestjs/typeorm';
import AsyncLock from 'async-lock'; import AsyncLock from 'async-lock';
import { DataSource, QueryRunner } from 'typeorm'; import { DataSource, EntityManager, QueryRunner } from 'typeorm';
import { isValidInteger } from '../infra.utils';
import { ImmichLogger } from '../logger';
@Injectable() @Injectable()
export class DatabaseRepository implements IDatabaseRepository { export class DatabaseRepository implements IDatabaseRepository {
private logger = new ImmichLogger(DatabaseRepository.name);
readonly asyncLock = new AsyncLock(); readonly asyncLock = new AsyncLock();
constructor(@InjectDataSource() private dataSource: DataSource) {} constructor(@InjectDataSource() private dataSource: DataSource) {}
async getExtensionVersion(extension: DatabaseExtension): Promise<Version | null> { async getExtensionVersion(extension: DatabaseExtension): Promise<Version | null> {
const res = await this.dataSource.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extension]); const res = await this.dataSource.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extension]);
const version = res[0]?.['extversion']; const extVersion = res[0]?.['extversion'];
if (extVersion == null) {
return null;
}
const version = Version.fromString(extVersion);
if (version.isEqual(new Version(0, 1, 1))) {
return new Version(0, 1, 11);
}
return version;
}
async getAvailableExtensionVersion(extension: DatabaseExtension): Promise<Version | null> {
const res = await this.dataSource.query(
`
SELECT version FROM pg_available_extension_versions
WHERE name = $1 AND installed = false
ORDER BY version DESC`,
[extension],
);
const version = res[0]?.['version'];
return version == null ? null : Version.fromString(version); return version == null ? null : Version.fromString(version);
} }
getPreferredVectorExtension(): VectorExtension {
return vectorExt;
}
async getPostgresVersion(): Promise<Version> { async getPostgresVersion(): Promise<Version> {
const res = await this.dataSource.query(`SHOW server_version`); const res = await this.dataSource.query(`SHOW server_version`);
return Version.fromString(res[0]['server_version']); return Version.fromString(res[0]['server_version']);
@ -25,6 +64,129 @@ export class DatabaseRepository implements IDatabaseRepository {
await this.dataSource.query(`CREATE EXTENSION IF NOT EXISTS ${extension}`); await this.dataSource.query(`CREATE EXTENSION IF NOT EXISTS ${extension}`);
} }
async updateExtension(extension: DatabaseExtension, version?: Version): Promise<void> {
await this.dataSource.query(`ALTER EXTENSION ${extension} UPDATE${version ? ` TO '${version}'` : ''}`);
}
async updateVectorExtension(extension: VectorExtension, version?: Version): Promise<VectorUpdateResult> {
const curVersion = await this.getExtensionVersion(extension);
if (!curVersion) {
throw new Error(`${extName[extension]} extension is not installed`);
}
const minorOrMajor = version && curVersion.isOlderThan(version) >= VersionType.MINOR;
const isVectors = extension === DatabaseExtension.VECTORS;
let restartRequired = false;
await this.dataSource.manager.transaction(async (manager) => {
await this.setSearchPath(manager);
if (minorOrMajor && isVectors) {
await this.updateVectorsSchema(manager, curVersion);
}
await manager.query(`ALTER EXTENSION ${extension} UPDATE${version ? ` TO '${version}'` : ''}`);
if (!minorOrMajor) {
return;
}
if (isVectors) {
await manager.query('SELECT pgvectors_upgrade()');
restartRequired = true;
} else {
await this.reindex(VectorIndex.CLIP);
await this.reindex(VectorIndex.FACE);
}
});
return { restartRequired };
}
async reindex(index: VectorIndex): Promise<void> {
try {
await this.dataSource.query(`REINDEX INDEX ${index}`);
} catch (error) {
if (vectorExt === DatabaseExtension.VECTORS) {
this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`);
const table = index === VectorIndex.CLIP ? 'smart_search' : 'asset_faces';
const dimSize = await this.getDimSize(table);
await this.dataSource.manager.transaction(async (manager) => {
await this.setSearchPath(manager);
await manager.query(`DROP INDEX IF EXISTS ${index}`);
await manager.query(`ALTER TABLE ${table} ALTER COLUMN embedding SET DATA TYPE real[]`);
await manager.query(`ALTER TABLE ${table} ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`);
await manager.query(`SET vectors.pgvector_compatibility=on`);
await manager.query(`
CREATE INDEX IF NOT EXISTS ${index} ON ${table}
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
});
} else {
throw error;
}
}
}
async shouldReindex(name: VectorIndex): Promise<boolean> {
if (vectorExt !== DatabaseExtension.VECTORS) {
return false;
}
try {
const res = await this.dataSource.query(
`
SELECT idx_status
FROM pg_vector_index_stat
WHERE indexname = $1`,
[name],
);
return res[0]?.['idx_status'] === 'UPGRADE';
} catch (error) {
const message: string = (error as any).message;
if (message.includes('index is not existing')) {
return true;
} else if (message.includes('relation "pg_vector_index_stat" does not exist')) {
return false;
}
throw error;
}
}
private async setSearchPath(manager: EntityManager): Promise<void> {
await manager.query(`SET search_path TO "$user", public, vectors`);
}
private async updateVectorsSchema(manager: EntityManager, curVersion: Version): Promise<void> {
await manager.query('CREATE SCHEMA IF NOT EXISTS vectors');
await manager.query(`UPDATE pg_catalog.pg_extension SET extversion = $1 WHERE extname = $2`, [
curVersion.toString(),
DatabaseExtension.VECTORS,
]);
await manager.query('UPDATE pg_catalog.pg_extension SET extrelocatable = true WHERE extname = $1', [
DatabaseExtension.VECTORS,
]);
await manager.query('ALTER EXTENSION vectors SET SCHEMA vectors');
await manager.query('UPDATE pg_catalog.pg_extension SET extrelocatable = false WHERE extname = $1', [
DatabaseExtension.VECTORS,
]);
}
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
const res = await this.dataSource.query(`
SELECT atttypmod as dimsize
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid
WHERE c.relkind = 'r'::char
AND f.attnum > 0
AND c.relname = '${table}'
AND f.attname = '${column}'`);
const dimSize = res[0]['dimsize'];
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Could not retrieve dimension size`);
}
return dimSize;
}
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> { async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
await this.dataSource.runMigrations(options); await this.dataSource.runMigrations(options);
} }

View file

@ -1,10 +1,18 @@
import { Embedding, EmbeddingSearch, FaceEmbeddingSearch, FaceSearchResult, ISmartInfoRepository } from '@app/domain'; import {
DatabaseExtension,
Embedding,
EmbeddingSearch,
FaceEmbeddingSearch,
FaceSearchResult,
ISmartInfoRepository,
} from '@app/domain';
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities'; import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
import { ImmichLogger } from '@app/infra/logger'; import { ImmichLogger } from '@app/infra/logger';
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm'; import { Repository } from 'typeorm';
import { vectorExt } from '../database.config';
import { DummyValue, GenerateSql } from '../infra.util'; import { DummyValue, GenerateSql } from '../infra.util';
import { asVector, isValidInteger } from '../infra.utils'; import { asVector, isValidInteger } from '../infra.utils';
@ -44,16 +52,20 @@ export class SmartInfoRepository implements ISmartInfoRepository {
params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }], params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }],
}) })
async searchCLIP({ userIds, embedding, numResults, withArchived }: EmbeddingSearch): Promise<AssetEntity[]> { async searchCLIP({ userIds, embedding, numResults, withArchived }: EmbeddingSearch): Promise<AssetEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
// setting this too low messes with prefilter recall
numResults = Math.max(numResults, 64);
let results: AssetEntity[] = []; let results: AssetEntity[] = [];
await this.assetRepository.manager.transaction(async (manager) => { await this.assetRepository.manager.transaction(async (manager) => {
await manager.query(`SET LOCAL vectors.enable_prefilter = on`); const query = manager
let query = manager
.createQueryBuilder(AssetEntity, 'a') .createQueryBuilder(AssetEntity, 'a')
.innerJoin('a.smartSearch', 's') .innerJoin('a.smartSearch', 's')
.leftJoinAndSelect('a.exifInfo', 'e') .leftJoinAndSelect('a.exifInfo', 'e')
.where('a.ownerId IN (:...userIds )') .where('a.ownerId IN (:...userIds )')
.orderBy('s.embedding <=> :embedding') .orderBy('s.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) }); .setParameters({ userIds, embedding: asVector(embedding) });
@ -61,15 +73,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
query.andWhere('a.isArchived = false'); query.andWhere('a.isArchived = false');
} }
query.andWhere('a.isVisible = true').andWhere('a.fileCreatedAt < NOW()'); query.andWhere('a.isVisible = true').andWhere('a.fileCreatedAt < NOW()');
query.limit(numResults);
if (numResults) { await manager.query(this.getRuntimeConfig(numResults));
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
query = query.limit(numResults);
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
}
results = await query.getMany(); results = await query.getMany();
}); });
@ -93,36 +99,34 @@ export class SmartInfoRepository implements ISmartInfoRepository {
maxDistance, maxDistance,
hasPerson, hasPerson,
}: FaceEmbeddingSearch): Promise<FaceSearchResult[]> { }: FaceEmbeddingSearch): Promise<FaceSearchResult[]> {
let results: Array<AssetFaceEntity & { distance: number }> = [];
await this.assetRepository.manager.transaction(async (manager) => {
await manager.query(`SET LOCAL vectors.enable_prefilter = on`);
let cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('1 + (faces.embedding <=> :embedding)', 'distance')
.innerJoin('faces.asset', 'asset')
.where('asset.ownerId IN (:...userIds )')
.orderBy('1 + (faces.embedding <=> :embedding)')
.setParameters({ userIds, embedding: asVector(embedding) });
if (numResults) {
if (!isValidInteger(numResults, { min: 1 })) { if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`); throw new Error(`Invalid value for 'numResults': ${numResults}`);
} }
cte = cte.limit(numResults);
if (numResults > 64) { // setting this too low messes with prefilter recall
// setting k too low messes with prefilter recall numResults = Math.max(numResults, 64);
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
} let results: Array<AssetFaceEntity & { distance: number }> = [];
} await this.assetRepository.manager.transaction(async (manager) => {
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('faces.embedding <=> :embedding', 'distance')
.innerJoin('faces.asset', 'asset')
.where('asset.ownerId IN (:...userIds )')
.orderBy('faces.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });
cte.limit(numResults);
if (hasPerson) { if (hasPerson) {
cte = cte.andWhere('faces."personId" IS NOT NULL'); cte.andWhere('faces."personId" IS NOT NULL');
} }
for (const col of this.faceColumns) { for (const col of this.faceColumns) {
cte.addSelect(`faces.${col}`, col); cte.addSelect(`faces.${col}`, col);
} }
await manager.query(this.getRuntimeConfig(numResults));
results = await manager results = await manager
.createQueryBuilder() .createQueryBuilder()
.select('res.*') .select('res.*')
@ -167,6 +171,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`); this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
await this.smartSearchRepository.manager.transaction(async (manager) => { await this.smartSearchRepository.manager.transaction(async (manager) => {
if (vectorExt === DatabaseExtension.VECTORS) {
await manager.query(`SET vectors.pgvector_compatibility=on`);
}
await manager.query(`DROP TABLE smart_search`); await manager.query(`DROP TABLE smart_search`);
await manager.query(` await manager.query(`
@ -175,12 +182,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
embedding vector(${dimSize}) NOT NULL )`); embedding vector(${dimSize}) NOT NULL )`);
await manager.query(` await manager.query(`
CREATE INDEX clip_index ON smart_search CREATE INDEX IF NOT EXISTS clip_index ON smart_search
USING vectors (embedding cosine_ops) WITH (options = $$ USING hnsw (embedding vector_cosine_ops)
[indexing.hnsw] WITH (ef_construction = 300, m = 16)`);
m = 16
ef_construction = 300
$$)`);
}); });
this.logger.log(`Successfully updated database CLIP dimension size from ${currentDimSize} to ${dimSize}.`); this.logger.log(`Successfully updated database CLIP dimension size from ${currentDimSize} to ${dimSize}.`);
@ -202,4 +206,17 @@ export class SmartInfoRepository implements ISmartInfoRepository {
} }
return dimSize; return dimSize;
} }
private getRuntimeConfig(numResults?: number): string {
if (vectorExt === DatabaseExtension.VECTOR) {
return 'SET LOCAL hnsw.ef_search = 1000;'; // mitigate post-filter recall
}
let runtimeConfig = 'SET LOCAL vectors.enable_prefilter=on; SET LOCAL vectors.search_mode=vbase;';
if (numResults) {
runtimeConfig += ` SET LOCAL vectors.hnsw_ef_search = ${numResults};`;
}
return runtimeConfig;
}
} }

View file

@ -3,9 +3,13 @@
-- SmartInfoRepository.searchCLIP -- SmartInfoRepository.searchCLIP
START TRANSACTION START TRANSACTION
SET SET
LOCAL vectors.enable_prefilter = on LOCAL vectors.enable_prefilter = on;
SET SET
LOCAL vectors.k = '100' LOCAL vectors.search_mode = vbase;
SET
LOCAL vectors.hnsw_ef_search = 100;
SELECT SELECT
"a"."id" AS "a_id", "a"."id" AS "a_id",
"a"."deviceAssetId" AS "a_deviceAssetId", "a"."deviceAssetId" AS "a_deviceAssetId",
@ -85,9 +89,13 @@ COMMIT
-- SmartInfoRepository.searchFaces -- SmartInfoRepository.searchFaces
START TRANSACTION START TRANSACTION
SET SET
LOCAL vectors.enable_prefilter = on LOCAL vectors.enable_prefilter = on;
SET SET
LOCAL vectors.k = '100' LOCAL vectors.search_mode = vbase;
SET
LOCAL vectors.hnsw_ef_search = 100;
WITH WITH
"cte" AS ( "cte" AS (
SELECT SELECT
@ -100,7 +108,7 @@ WITH
"faces"."boundingBoxY1" AS "boundingBoxY1", "faces"."boundingBoxY1" AS "boundingBoxY1",
"faces"."boundingBoxX2" AS "boundingBoxX2", "faces"."boundingBoxX2" AS "boundingBoxX2",
"faces"."boundingBoxY2" AS "boundingBoxY2", "faces"."boundingBoxY2" AS "boundingBoxY2",
1 + ("faces"."embedding" <= > $1) AS "distance" "faces"."embedding" <= > $1 AS "distance"
FROM FROM
"asset_faces" "faces" "asset_faces" "faces"
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId" INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
@ -108,7 +116,7 @@ WITH
WHERE WHERE
"asset"."ownerId" IN ($2) "asset"."ownerId" IN ($2)
ORDER BY ORDER BY
1 + ("faces"."embedding" <= > $1) ASC "faces"."embedding" <= > $1 ASC
LIMIT LIMIT
100 100
) )

View file

@ -3,8 +3,14 @@ import { IDatabaseRepository, Version } from '@app/domain';
export const newDatabaseRepositoryMock = (): jest.Mocked<IDatabaseRepository> => { export const newDatabaseRepositoryMock = (): jest.Mocked<IDatabaseRepository> => {
return { return {
getExtensionVersion: jest.fn(), getExtensionVersion: jest.fn(),
getAvailableExtensionVersion: jest.fn(),
getPreferredVectorExtension: jest.fn(),
getPostgresVersion: jest.fn().mockResolvedValue(new Version(14, 0, 0)), getPostgresVersion: jest.fn().mockResolvedValue(new Version(14, 0, 0)),
createExtension: jest.fn().mockImplementation(() => Promise.resolve()), createExtension: jest.fn().mockImplementation(() => Promise.resolve()),
updateExtension: jest.fn(),
updateVectorExtension: jest.fn(),
reindex: jest.fn(),
shouldReindex: jest.fn(),
runMigrations: jest.fn(), runMigrations: jest.fn(),
withLock: jest.fn().mockImplementation((_, function_: <R>() => Promise<R>) => function_()), withLock: jest.fn().mockImplementation((_, function_: <R>() => Promise<R>) => function_()),
isBusy: jest.fn(), isBusy: jest.fn(),