1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2024-12-28 22:51:59 +00:00

feat(server)!: pgvecto.rs 0.2 and pgvector compatibility (#6785)

* basic changes

update version check

set ef_search for clip

* pgvector compatibility

Revert "pgvector compatibility"

This reverts commit 2b66a52aa4097dd27da58138c5288fd87cb9b24a.

pgvector compatibility: minimal edition

pgvector startup check

* update extension at startup

* wording

shortened vector extension variable name

* nightly docker

* fixed version checks

* update tests

add tests for updating extension

remove unnecessary check

* simplify `getRuntimeConfig`

* wording

* reindex on minor version update

* 0.2 upgrade testing

update prod compose

* acquire lock for init

* wip vector down on shutdown

* use upgrade helper

* update image tag

* refine restart check

check error message

* test reindex

testing

upstream fix

formatting

fixed reindexing

* use enum in signature

* fix tests

remove unused code

* add reindexing tests

* update to official 0.2

remove alpha from version name

* add warning test if restart required

* update test image to 0.2.0

* linting and test cleanup

* formatting

* update sql

* wording

* handle setting search path for new and existing databases

* handle new db in reindex check

* fix post-update reindexing

* get dim size

* formatting

* use vbase

* handle different db name

* update sql

* linting

* fix suggested env
This commit is contained in:
Mert 2024-02-06 21:46:38 -05:00 committed by GitHub
parent b2775c445a
commit 56b0643890
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 650 additions and 246 deletions

View file

@ -275,7 +275,7 @@ jobs:
runs-on: ubuntu-latest
services:
postgres:
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee
image: tensorchord/pgvecto-rs:pg14-v0.2.0
env:
POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres

View file

@ -25,7 +25,7 @@ export default async () => {
if (process.env.DB_HOSTNAME === undefined) {
// DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container.
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11')
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
.withExposedPorts(5432)
.withDatabase('immich')
.withUsername('postgres')

View file

@ -103,7 +103,7 @@ services:
database:
container_name: immich_postgres
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee
image: tensorchord/pgvecto-rs:pg14-v0.2.0
env_file:
- .env
environment:

View file

@ -61,7 +61,7 @@ services:
database:
container_name: immich_postgres
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee
image: tensorchord/pgvecto-rs:pg14-v0.2.0
env_file:
- .env
environment:
@ -70,7 +70,8 @@ services:
POSTGRES_DB: ${DB_DATABASE_NAME}
volumes:
- ${UPLOAD_LOCATION}/postgres:/var/lib/postgresql/data
restart: always
ports:
- 5432:5432
volumes:
model-cache:

View file

@ -65,7 +65,7 @@ services:
database:
container_name: immich_postgres
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee
image: tensorchord/pgvecto-rs:pg14-v0.2.0
env_file:
- .env
environment:

View file

@ -1,7 +1,7 @@
import { PostgreSqlContainer } from '@testcontainers/postgresql';
export default async () => {
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11')
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
.withDatabase('immich')
.withUsername('postgres')
.withPassword('postgres')

View file

@ -21,7 +21,7 @@ services:
- database
database:
image: tensorchord/pgvecto-rs:pg14-v0.1.11@sha256:0335a1a22f8c5dd1b697f14f079934f5152eaaa216c09b61e293be285491f8ee
image: tensorchord/pgvecto-rs:pg14-v0.2.0
command: -c fsync=off -c shared_preload_libraries=vectors.so
environment:
POSTGRES_PASSWORD: postgres

View file

@ -25,7 +25,7 @@ export default async () => {
if (process.env.DB_HOSTNAME === undefined) {
// DB hostname not set which likely means we're not running e2e through docker compose. Start a local postgres container.
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.1.11')
const pg = await new PostgreSqlContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
.withExposedPorts(5432)
.withDatabase('immich')
.withUsername('postgres')

View file

@ -1,41 +1,65 @@
import { DatabaseExtension, DatabaseService, IDatabaseRepository, Version } from '@app/domain';
import {
DatabaseExtension,
DatabaseService,
IDatabaseRepository,
VectorIndex,
Version,
VersionType,
} from '@app/domain';
import { ImmichLogger } from '@app/infra/logger';
import { newDatabaseRepositoryMock } from '@test';
describe(DatabaseService.name, () => {
let sut: DatabaseService;
let databaseMock: jest.Mocked<IDatabaseRepository>;
let fatalLog: jest.SpyInstance;
beforeEach(async () => {
databaseMock = newDatabaseRepositoryMock();
fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal');
sut = new DatabaseService(databaseMock);
sut.minVectorsVersion = new Version(0, 1, 1);
sut.maxVectorsVersion = new Version(0, 1, 11);
});
afterEach(() => {
fatalLog.mockRestore();
});
it('should work', () => {
expect(sut).toBeDefined();
});
describe('init', () => {
it('should resolve successfully if minimum supported PostgreSQL and vectors version are installed', async () => {
describe.each([
[{ vectorExt: DatabaseExtension.VECTORS, extName: 'pgvecto.rs', minVersion: new Version(0, 1, 1) }],
[{ vectorExt: DatabaseExtension.VECTOR, extName: 'pgvector', minVersion: new Version(0, 5, 0) }],
] as const)('init', ({ vectorExt, extName, minVersion }) => {
let fatalLog: jest.SpyInstance;
let errorLog: jest.SpyInstance;
let warnLog: jest.SpyInstance;
beforeEach(async () => {
fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal');
errorLog = jest.spyOn(ImmichLogger.prototype, 'error');
warnLog = jest.spyOn(ImmichLogger.prototype, 'warn');
databaseMock.getPreferredVectorExtension.mockReturnValue(vectorExt);
databaseMock.getExtensionVersion.mockResolvedValue(minVersion);
sut = new DatabaseService(databaseMock);
sut.minVectorVersion = minVersion;
sut.minVectorsVersion = minVersion;
sut.vectorVersionPin = VersionType.MINOR;
sut.vectorsVersionPin = VersionType.MINOR;
});
afterEach(() => {
fatalLog.mockRestore();
warnLog.mockRestore();
});
it(`should resolve successfully if minimum supported PostgreSQL and ${extName} version are installed`, async () => {
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(14, 0, 0));
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1));
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(2);
expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS);
expect(databaseMock.getPostgresVersion).toHaveBeenCalled();
expect(databaseMock.createExtension).toHaveBeenCalledWith(vectorExt);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalled();
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
@ -43,112 +67,162 @@ describe(DatabaseService.name, () => {
it('should throw an error if PostgreSQL version is below minimum supported version', async () => {
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(13, 0, 0));
await expect(sut.init()).rejects.toThrow(/PostgreSQL version is 13/s);
await expect(sut.init()).rejects.toThrow('PostgreSQL version is 13');
expect(databaseMock.getPostgresVersion).toHaveBeenCalledTimes(1);
});
it('should resolve successfully if minimum supported vectors version is installed', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1));
it(`should resolve successfully if minimum supported ${extName} version is installed`, async () => {
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS);
expect(databaseMock.createExtension).toHaveBeenCalledWith(vectorExt);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it('should resolve successfully if maximum supported vectors version is installed', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 11));
it(`should throw an error if ${extName} version is not installed even after createVectorExtension`, async () => {
databaseMock.getExtensionVersion.mockResolvedValue(null);
await expect(sut.init()).resolves.toBeUndefined();
await expect(sut.init()).rejects.toThrow(`Unexpected: ${extName} extension is not installed.`);
expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it('should throw an error if vectors version is not installed even after createVectors', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(null);
await expect(sut.init()).rejects.toThrow('Unexpected: The pgvecto.rs extension is not installed.');
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
});
it('should throw an error if vectors version is below minimum supported version', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 1));
await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0.1.11')/s);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
});
it('should throw an error if vectors version is above maximum supported version', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 12));
await expect(sut.init()).rejects.toThrow(
/('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s,
it(`should throw an error if ${extName} version is below minimum supported version`, async () => {
databaseMock.getExtensionVersion.mockResolvedValue(
new Version(minVersion.major, minVersion.minor - 1, minVersion.patch),
);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
await expect(sut.init()).rejects.toThrow(extName);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
});
it('should throw an error if vectors version is a nightly', async () => {
databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 0));
it.each([
{ type: VersionType.EQUAL, max: 'no', actual: 'patch' },
{ type: VersionType.PATCH, max: 'patch', actual: 'minor' },
{ type: VersionType.MINOR, max: 'minor', actual: 'major' },
] as const)(
`should throw an error if $max upgrade from min version is allowed and ${extName} version is $actual`,
async ({ type, actual }) => {
const version = new Version(minVersion.major, minVersion.minor, minVersion.patch);
version[actual] = minVersion[actual] + 1;
databaseMock.getExtensionVersion.mockResolvedValue(version);
if (vectorExt === DatabaseExtension.VECTOR) {
sut.minVectorVersion = minVersion;
sut.vectorVersionPin = type;
} else {
sut.minVectorsVersion = minVersion;
sut.vectorsVersionPin = type;
}
await expect(sut.init()).rejects.toThrow(
/(nightly).*('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s,
);
await expect(sut.init()).rejects.toThrow(extName);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
},
);
it(`should throw an error if ${extName} version is a nightly`, async () => {
databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 0));
await expect(sut.init()).rejects.toThrow(extName);
expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
});
it('should throw error if vectors extension could not be created', async () => {
databaseMock.createExtension.mockRejectedValueOnce(new Error('Failed to create extension'));
it(`should throw error if ${extName} extension could not be created`, async () => {
databaseMock.createExtension.mockRejectedValue(new Error('Failed to create extension'));
await expect(sut.init()).rejects.toThrow('Failed to create extension');
expect(fatalLog).toHaveBeenCalledTimes(1);
expect(fatalLog.mock.calls[0][0]).toMatch(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11').*(v1\.91\.0)/s);
expect(databaseMock.createExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).not.toHaveBeenCalled();
});
it.each([{ major: 14 }, { major: 15 }, { major: 16 }])(
`should suggest image with postgres $major if database is $major`,
async ({ major }) => {
databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1));
databaseMock.getPostgresVersion.mockResolvedValue(new Version(major, 0, 0));
it(`should update ${extName} if a newer version is available`, async () => {
const version = new Version(minVersion.major, minVersion.minor + 1, minVersion.patch);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
await expect(sut.init()).rejects.toThrow(new RegExp(`tensorchord\/pgvecto-rs:pg${major}-v0\\.1\\.11`, 's'));
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version);
expect(databaseMock.updateVectorExtension).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it(`should not update ${extName} if a newer version is higher than the maximum`, async () => {
const version = new Version(minVersion.major + 1, minVersion.minor, minVersion.patch);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.updateVectorExtension).not.toHaveBeenCalled();
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it(`should warn if attempted to update ${extName} and failed`, async () => {
const version = new Version(minVersion.major, minVersion.minor, minVersion.patch + 1);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
databaseMock.updateVectorExtension.mockRejectedValue(new Error('Failed to update extension'));
await expect(sut.init()).resolves.toBeUndefined();
expect(warnLog).toHaveBeenCalledTimes(1);
expect(warnLog.mock.calls[0][0]).toContain(extName);
expect(errorLog).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
});
it(`should warn if ${extName} update requires restart`, async () => {
const version = new Version(minVersion.major, minVersion.minor, minVersion.patch + 1);
databaseMock.getAvailableExtensionVersion.mockResolvedValue(version);
databaseMock.updateVectorExtension.mockResolvedValue({ restartRequired: true });
await expect(sut.init()).resolves.toBeUndefined();
expect(warnLog).toHaveBeenCalledTimes(1);
expect(warnLog.mock.calls[0][0]).toContain(extName);
expect(databaseMock.updateVectorExtension).toHaveBeenCalledWith(vectorExt, version);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
});
it.each([{ index: VectorIndex.CLIP }, { index: VectorIndex.FACE }])(
`should reindex $index if necessary`,
async ({ index }) => {
databaseMock.shouldReindex.mockImplementation((indexArg) => Promise.resolve(indexArg === index));
await expect(sut.init()).resolves.toBeUndefined();
expect(databaseMock.shouldReindex).toHaveBeenCalledWith(index);
expect(databaseMock.shouldReindex).toHaveBeenCalledTimes(2);
expect(databaseMock.reindex).toHaveBeenCalledWith(index);
expect(databaseMock.reindex).toHaveBeenCalledTimes(1);
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
},
);
it('should not suggest image if postgres version is not in 14, 15 or 16', async () => {
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(17, 0, 0));
databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(17, 0, 0));
it.each([{ index: VectorIndex.CLIP }, { index: VectorIndex.FACE }])(
`should not reindex $index if not necessary`,
async () => {
databaseMock.shouldReindex.mockResolvedValue(false);
await expect(sut.init()).rejects.toThrow(/^(?:(?!tensorchord\/pgvecto-rs).)*$/s);
});
await expect(sut.init()).resolves.toBeUndefined();
it('should reject and suggest the maximum supported version when unsupported pgvecto.rs version is in use', async () => {
databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1));
await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s);
sut.maxVectorsVersion = new Version(0, 1, 12);
await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.12')/s);
});
expect(databaseMock.shouldReindex).toHaveBeenCalledTimes(2);
expect(databaseMock.reindex).not.toHaveBeenCalled();
expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1);
expect(fatalLog).not.toHaveBeenCalled();
},
);
});
});

View file

@ -1,74 +1,56 @@
import { ImmichLogger } from '@app/infra/logger';
import { Inject, Injectable } from '@nestjs/common';
import { QueryFailedError } from 'typeorm';
import { Version } from '../domain.constant';
import { DatabaseExtension, IDatabaseRepository } from '../repositories';
import { Version, VersionType } from '../domain.constant';
import {
DatabaseExtension,
DatabaseLock,
IDatabaseRepository,
VectorExtension,
VectorIndex,
extName,
} from '../repositories';
@Injectable()
export class DatabaseService {
private logger = new ImmichLogger(DatabaseService.name);
private vectorExt: VectorExtension;
minPostgresVersion = 14;
minVectorsVersion = new Version(0, 1, 1);
maxVectorsVersion = new Version(0, 1, 11);
minVectorsVersion = new Version(0, 2, 0);
vectorsVersionPin = VersionType.MINOR;
minVectorVersion = new Version(0, 5, 0);
vectorVersionPin = VersionType.MAJOR;
constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) {}
constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) {
this.vectorExt = this.databaseRepository.getPreferredVectorExtension();
}
async init() {
await this.assertPostgresql();
await this.createVectors();
await this.assertVectors();
await this.databaseRepository.runMigrations();
}
await this.databaseRepository.withLock(DatabaseLock.Migrations, async () => {
await this.createVectorExtension();
await this.updateVectorExtension();
await this.assertVectorExtension();
private async assertVectors() {
const version = await this.databaseRepository.getExtensionVersion(DatabaseExtension.VECTORS);
if (version == null) {
throw new Error('Unexpected: The pgvecto.rs extension is not installed.');
}
try {
if (await this.databaseRepository.shouldReindex(VectorIndex.CLIP)) {
await this.databaseRepository.reindex(VectorIndex.CLIP);
}
const image = await this.getVectorsImage();
const suggestion = image ? `, such as with the docker image '${image}'` : '';
if (await this.databaseRepository.shouldReindex(VectorIndex.FACE)) {
await this.databaseRepository.reindex(VectorIndex.FACE);
}
} catch (error) {
this.logger.warn(
'Could not run vector reindexing checks. If the extension was updated, please restart the Postgres instance.',
);
throw error;
}
if (version.isEqual(new Version(0, 0, 0))) {
throw new Error(
`The pgvecto.rs extension version is ${version}, which means it is a nightly release.` +
`Please run 'DROP EXTENSION IF EXISTS vectors' and switch to a release version${suggestion}.`,
);
}
if (version.isNewerThan(this.maxVectorsVersion)) {
throw new Error(`
The pgvecto.rs extension version is ${version} instead of ${this.maxVectorsVersion}.
Please run 'DROP EXTENSION IF EXISTS vectors' and switch to ${this.maxVectorsVersion}${suggestion}.`);
}
if (version.isOlderThan(this.minVectorsVersion)) {
throw new Error(`
The pgvecto.rs extension version is ${version}, which is older than the minimum supported version ${this.minVectorsVersion}.
Please upgrade to this version or later${suggestion}.`);
}
}
private async createVectors() {
await this.databaseRepository.createExtension(DatabaseExtension.VECTORS).catch(async (error: QueryFailedError) => {
const image = await this.getVectorsImage();
this.logger.fatal(`
Failed to create pgvecto.rs extension.
If you have not updated your Postgres instance to a docker image that supports pgvecto.rs (such as '${image}'), please do so.
See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0'
`);
throw error;
await this.databaseRepository.runMigrations();
});
}
private async getVectorsImage() {
const { major } = await this.databaseRepository.getPostgresVersion();
if (![14, 15, 16].includes(major)) {
return null;
}
return `tensorchord/pgvecto-rs:pg${major}-v${this.maxVectorsVersion}`;
}
private async assertPostgresql() {
const { major } = await this.databaseRepository.getPostgresVersion();
if (major < this.minPostgresVersion) {
@ -77,4 +59,99 @@ export class DatabaseService {
Please upgrade to this version or later.`);
}
}
private async createVectorExtension() {
await this.databaseRepository.createExtension(this.vectorExt).catch(async (error: QueryFailedError) => {
const otherExt =
this.vectorExt === DatabaseExtension.VECTORS ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS;
this.logger.fatal(`
Failed to activate ${extName[this.vectorExt]} extension.
Please ensure the Postgres instance has ${extName[this.vectorExt]} installed.
If the Postgres instance already has ${extName[this.vectorExt]} installed, Immich may not have the necessary permissions to activate it.
In this case, please run 'CREATE EXTENSION IF NOT EXISTS ${this.vectorExt}' manually as a superuser.
See https://immich.app/docs/guides/database-queries for how to query the database.
Alternatively, if your Postgres instance has ${extName[otherExt]}, you may use this instead by setting the environment variable 'VECTOR_EXTENSION=${otherExt}'.
Note that switching between the two extensions after a successful startup is not supported.
The exception is if your version of Immich prior to upgrading was 1.90.2 or earlier.
In this case, you may set either extension now, but you will not be able to switch to the other extension following a successful startup.
`);
throw error;
});
}
private async updateVectorExtension() {
const [version, availableVersion] = await Promise.all([
this.databaseRepository.getExtensionVersion(this.vectorExt),
this.databaseRepository.getAvailableExtensionVersion(this.vectorExt),
]);
if (version == null) {
throw new Error(`Unexpected: ${extName[this.vectorExt]} extension is not installed.`);
}
if (availableVersion == null) {
return;
}
const maxVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.vectorVersionPin : this.vectorsVersionPin;
const isNewer = availableVersion.isNewerThan(version);
if (isNewer == null || isNewer > maxVersion) {
return;
}
try {
this.logger.log(`Updating ${extName[this.vectorExt]} extension to ${availableVersion}`);
const { restartRequired } = await this.databaseRepository.updateVectorExtension(this.vectorExt, availableVersion);
if (restartRequired) {
this.logger.warn(`
The ${extName[this.vectorExt]} extension has been updated to ${availableVersion}.
Please restart the Postgres instance to complete the update.`);
}
} catch (error) {
this.logger.warn(`
The ${extName[this.vectorExt]} extension version is ${version}, but ${availableVersion} is available.
Immich attempted to update the extension, but failed to do so.
This may be because Immich does not have the necessary permissions to update the extension.
Please run 'ALTER EXTENSION ${this.vectorExt} UPDATE' manually as a superuser.
See https://immich.app/docs/guides/database-queries for how to query the database.`);
this.logger.error(error);
}
}
private async assertVectorExtension() {
const version = await this.databaseRepository.getExtensionVersion(this.vectorExt);
if (version == null) {
throw new Error(`Unexpected: The ${extName[this.vectorExt]} extension is not installed.`);
}
if (version.isEqual(new Version(0, 0, 0))) {
throw new Error(`
The ${extName[this.vectorExt]} extension version is ${version}, which means it is a nightly release.
Please run 'DROP EXTENSION IF EXISTS ${this.vectorExt}' and switch to a release version.
See https://immich.app/docs/guides/database-queries for how to query the database.`);
}
const minVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.minVectorVersion : this.minVectorsVersion;
const maxVersion = this.vectorExt === DatabaseExtension.VECTOR ? this.vectorVersionPin : this.vectorsVersionPin;
if (version.isOlderThan(minVersion) || version.isNewerThan(minVersion) > maxVersion) {
const allowedReleaseType = maxVersion === VersionType.MAJOR ? '' : ` ${VersionType[maxVersion].toLowerCase()}`;
const releases =
maxVersion === VersionType.EQUAL
? minVersion.toString()
: `${minVersion} and later${allowedReleaseType} releases`;
throw new Error(`
The ${extName[this.vectorExt]} extension version is ${version}, but Immich only supports ${releases}.
If the Postgres instance already has a compatible version installed, Immich may not have the necessary permissions to activate it.
In this case, please run 'ALTER EXTENSION UPDATE ${this.vectorExt}' manually as a superuser.
See https://immich.app/docs/guides/database-queries for how to query the database.
Otherwise, please update the version of ${extName[this.vectorExt]} in the Postgres instance to a compatible version.`);
}
}
}

View file

@ -24,5 +24,6 @@ export const immichAppConfig: ConfigModuleOptions = {
MACHINE_LEARNING_PORT: Joi.number().optional(),
MICROSERVICES_PORT: Joi.number().optional(),
SERVER_PORT: Joi.number().optional(),
VECTOR_EXTENSION: Joi.string().optional().valid('pgvector', 'pgvecto.rs').default('pgvecto.rs'),
}),
};

View file

@ -1,4 +1,4 @@
import { Version, mimeTypes } from './domain.constant';
import { Version, VersionType, mimeTypes } from './domain.constant';
describe('mimeTypes', () => {
for (const { mimetype, extension } of [
@ -196,45 +196,37 @@ describe('mimeTypes', () => {
});
});
describe('ServerVersion', () => {
describe('Version', () => {
const tests = [
{ this: new Version(0, 0, 1), other: new Version(0, 0, 0), expected: 1 },
{ this: new Version(0, 1, 0), other: new Version(0, 0, 0), expected: 1 },
{ this: new Version(1, 0, 0), other: new Version(0, 0, 0), expected: 1 },
{ this: new Version(0, 0, 0), other: new Version(0, 0, 1), expected: -1 },
{ this: new Version(0, 0, 0), other: new Version(0, 1, 0), expected: -1 },
{ this: new Version(0, 0, 0), other: new Version(1, 0, 0), expected: -1 },
{ this: new Version(0, 0, 0), other: new Version(0, 0, 0), expected: 0 },
{ this: new Version(0, 0, 1), other: new Version(0, 0, 1), expected: 0 },
{ this: new Version(0, 1, 0), other: new Version(0, 1, 0), expected: 0 },
{ this: new Version(1, 0, 0), other: new Version(1, 0, 0), expected: 0 },
{ this: new Version(1, 0), other: new Version(1, 0, 0), expected: 0 },
{ this: new Version(1, 0), other: new Version(1, 0, 1), expected: -1 },
{ this: new Version(1, 1), other: new Version(1, 0, 1), expected: 1 },
{ this: new Version(1), other: new Version(1, 0, 0), expected: 0 },
{ this: new Version(1), other: new Version(1, 0, 1), expected: -1 },
{ this: new Version(0, 0, 1), other: new Version(0, 0, 0), compare: 1, type: VersionType.PATCH },
{ this: new Version(0, 1, 0), other: new Version(0, 0, 0), compare: 1, type: VersionType.MINOR },
{ this: new Version(1, 0, 0), other: new Version(0, 0, 0), compare: 1, type: VersionType.MAJOR },
{ this: new Version(0, 0, 0), other: new Version(0, 0, 1), compare: -1, type: VersionType.PATCH },
{ this: new Version(0, 0, 0), other: new Version(0, 1, 0), compare: -1, type: VersionType.MINOR },
{ this: new Version(0, 0, 0), other: new Version(1, 0, 0), compare: -1, type: VersionType.MAJOR },
{ this: new Version(0, 0, 0), other: new Version(0, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(0, 0, 1), other: new Version(0, 0, 1), compare: 0, type: VersionType.EQUAL },
{ this: new Version(0, 1, 0), other: new Version(0, 1, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1, 0, 0), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1, 0), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1, 0), other: new Version(1, 0, 1), compare: -1, type: VersionType.PATCH },
{ this: new Version(1, 1), other: new Version(1, 0, 1), compare: 1, type: VersionType.MINOR },
{ this: new Version(1), other: new Version(1, 0, 0), compare: 0, type: VersionType.EQUAL },
{ this: new Version(1), other: new Version(1, 0, 1), compare: -1, type: VersionType.PATCH },
];
describe('compare', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) {
it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.compare(otherVersion)).toEqual(expected);
});
}
});
describe('isOlderThan', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) {
const bool = expected < 0;
it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isOlderThan(otherVersion)).toEqual(bool);
for (const { this: thisVersion, other: otherVersion, compare, type } of tests) {
const expected = compare < 0 ? type : VersionType.EQUAL;
it(`should return '${expected}' when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isOlderThan(otherVersion)).toEqual(expected);
});
}
});
describe('isEqual', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) {
const bool = expected === 0;
for (const { this: thisVersion, other: otherVersion, compare } of tests) {
const bool = compare === 0;
it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isEqual(otherVersion)).toEqual(bool);
});
@ -242,10 +234,10 @@ describe('ServerVersion', () => {
});
describe('isNewerThan', () => {
for (const { this: thisVersion, other: otherVersion, expected } of tests) {
const bool = expected > 0;
it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isNewerThan(otherVersion)).toEqual(bool);
for (const { this: thisVersion, other: otherVersion, compare, type } of tests) {
const expected = compare > 0 ? type : VersionType.EQUAL;
it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => {
expect(thisVersion.isNewerThan(otherVersion)).toEqual(expected);
});
}
});

View file

@ -12,11 +12,20 @@ export interface IVersion {
patch: number;
}
export enum VersionType {
EQUAL = 0,
PATCH = 1,
MINOR = 2,
MAJOR = 3,
}
export class Version implements IVersion {
public readonly types = ['major', 'minor', 'patch'] as const;
constructor(
public readonly major: number,
public readonly minor: number = 0,
public readonly patch: number = 0,
public major: number,
public minor: number = 0,
public patch: number = 0,
) {}
toString() {
@ -39,27 +48,30 @@ export class Version implements IVersion {
}
}
compare(version: Version): number {
for (const key of ['major', 'minor', 'patch'] as const) {
private compare(version: Version): [number, VersionType] {
for (const [i, key] of this.types.entries()) {
const diff = this[key] - version[key];
if (diff !== 0) {
return diff > 0 ? 1 : -1;
return [diff > 0 ? 1 : -1, (VersionType.MAJOR - i) as VersionType];
}
}
return 0;
return [0, VersionType.EQUAL];
}
isOlderThan(version: Version): boolean {
return this.compare(version) < 0;
isOlderThan(version: Version): VersionType {
const [bool, type] = this.compare(version);
return bool < 0 ? type : VersionType.EQUAL;
}
isEqual(version: Version): boolean {
return this.compare(version) === 0;
const [bool] = this.compare(version);
return bool === 0;
}
isNewerThan(version: Version): boolean {
return this.compare(version) > 0;
isNewerThan(version: Version): VersionType {
const [bool, type] = this.compare(version);
return bool > 0 ? type : VersionType.EQUAL;
}
}

View file

@ -3,21 +3,47 @@ import { Version } from '../domain.constant';
export enum DatabaseExtension {
CUBE = 'cube',
EARTH_DISTANCE = 'earthdistance',
VECTOR = 'vector',
VECTORS = 'vectors',
}
export type VectorExtension = DatabaseExtension.VECTOR | DatabaseExtension.VECTORS;
export enum VectorIndex {
CLIP = 'clip_index',
FACE = 'face_index',
}
export enum DatabaseLock {
GeodataImport = 100,
Migrations = 200,
StorageTemplateMigration = 420,
CLIPDimSize = 512,
}
export const extName: Record<DatabaseExtension, string> = {
cube: 'cube',
earthdistance: 'earthdistance',
vector: 'pgvector',
vectors: 'pgvecto.rs',
} as const;
export interface VectorUpdateResult {
restartRequired: boolean;
}
export const IDatabaseRepository = 'IDatabaseRepository';
export interface IDatabaseRepository {
getExtensionVersion(extensionName: string): Promise<Version | null>;
getAvailableExtensionVersion(extension: DatabaseExtension): Promise<Version | null>;
getPreferredVectorExtension(): VectorExtension;
getPostgresVersion(): Promise<Version>;
createExtension(extension: DatabaseExtension): Promise<void>;
updateExtension(extension: DatabaseExtension, version?: Version): Promise<void>;
updateVectorExtension(extension: VectorExtension, version?: Version): Promise<VectorUpdateResult>;
reindex(index: VectorIndex): Promise<void>;
shouldReindex(name: VectorIndex): Promise<boolean>;
runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void>;
withLock<R>(lock: DatabaseLock, callback: () => Promise<R>): Promise<R>;
isBusy(lock: DatabaseLock): boolean;

View file

@ -7,7 +7,7 @@ export type Embedding = number[];
export interface EmbeddingSearch {
userIds: string[];
embedding: Embedding;
numResults?: number;
numResults: number;
withArchived?: boolean;
}

View file

@ -1,3 +1,4 @@
import { DatabaseExtension } from '@app/domain/repositories/database.repository';
import { DataSource } from 'typeorm';
import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions.js';
@ -27,3 +28,6 @@ export const databaseConfig: PostgresConnectionOptions = {
// this export is used by TypeORM commands in package.json#scripts
export const dataSource = new DataSource(databaseConfig);
export const vectorExt =
process.env.VECTOR_EXTENSION === 'pgvector' ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS;

View file

@ -5,7 +5,7 @@ import { LogLevel } from './entities';
const LOG_LEVELS = [LogLevel.VERBOSE, LogLevel.DEBUG, LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL];
export class ImmichLogger extends ConsoleLogger {
private static logLevels: LogLevel[] = [LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL];
private static logLevels: LogLevel[] = [LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL];
constructor(context: string) {
super(context);

View file

@ -1,11 +1,13 @@
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
import { MigrationInterface, QueryRunner } from 'typeorm';
import { vectorExt } from '@app/infra/database.config';
export class UsePgVectors1700713871511 implements MigrationInterface {
name = 'UsePgVectors1700713871511';
public async up(queryRunner: QueryRunner): Promise<void> {
await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS vectors`);
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS ${vectorExt}`);
const faceDimQuery = await queryRunner.query(`
SELECT CARDINALITY(embedding::real[]) as dimsize
FROM asset_faces

View file

@ -1,16 +1,20 @@
import { MigrationInterface, QueryRunner } from 'typeorm';
import { vectorExt } from '../database.config';
import { DatabaseExtension } from '@app/domain/repositories/database.repository';
export class AddCLIPEmbeddingIndex1700713994428 implements MigrationInterface {
name = 'AddCLIPEmbeddingIndex1700713994428';
public async up(queryRunner: QueryRunner): Promise<void> {
if (vectorExt === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`
CREATE INDEX IF NOT EXISTS clip_index ON smart_search
USING vectors (embedding cosine_ops) WITH (options = $$
[indexing.hnsw]
m = 16
ef_construction = 300
$$);`);
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
}
public async down(queryRunner: QueryRunner): Promise<void> {

View file

@ -1,16 +1,20 @@
import { MigrationInterface, QueryRunner } from 'typeorm';
import { vectorExt } from '../database.config';
import { DatabaseExtension } from '@app/domain/repositories/database.repository';
export class AddFaceEmbeddingIndex1700714033632 implements MigrationInterface {
name = 'AddFaceEmbeddingIndex1700714033632';
public async up(queryRunner: QueryRunner): Promise<void> {
if (vectorExt === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`
CREATE INDEX IF NOT EXISTS face_index ON asset_faces
USING vectors (embedding cosine_ops) WITH (options = $$
[indexing.hnsw]
m = 16
ef_construction = 300
$$);`);
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
}
public async down(queryRunner: QueryRunner): Promise<void> {

View file

@ -0,0 +1,14 @@
import { MigrationInterface, QueryRunner } from 'typeorm';
export class AddVectorsToSearchPath1707000751533 implements MigrationInterface {
public async up(queryRunner: QueryRunner): Promise<void> {
const res = await queryRunner.query(`SELECT current_database() as db`);
const databaseName = res[0]['db'];
await queryRunner.query(`ALTER DATABASE ${databaseName} SET search_path TO "$user", public, vectors`);
}
public async down(queryRunner: QueryRunner): Promise<void> {
const databaseName = await queryRunner.query(`SELECT current_database()`);
await queryRunner.query(`ALTER DATABASE ${databaseName} SET search_path TO "$user", public`);
}
}

View file

@ -1,21 +1,60 @@
import { DatabaseExtension, DatabaseLock, IDatabaseRepository, Version } from '@app/domain';
import {
DatabaseExtension,
DatabaseLock,
IDatabaseRepository,
VectorExtension,
VectorIndex,
VectorUpdateResult,
Version,
VersionType,
extName,
} from '@app/domain';
import { vectorExt } from '@app/infra/database.config';
import { Injectable } from '@nestjs/common';
import { InjectDataSource } from '@nestjs/typeorm';
import AsyncLock from 'async-lock';
import { DataSource, QueryRunner } from 'typeorm';
import { DataSource, EntityManager, QueryRunner } from 'typeorm';
import { isValidInteger } from '../infra.utils';
import { ImmichLogger } from '../logger';
@Injectable()
export class DatabaseRepository implements IDatabaseRepository {
private logger = new ImmichLogger(DatabaseRepository.name);
readonly asyncLock = new AsyncLock();
constructor(@InjectDataSource() private dataSource: DataSource) {}
async getExtensionVersion(extension: DatabaseExtension): Promise<Version | null> {
const res = await this.dataSource.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extension]);
const version = res[0]?.['extversion'];
const extVersion = res[0]?.['extversion'];
if (extVersion == null) {
return null;
}
const version = Version.fromString(extVersion);
if (version.isEqual(new Version(0, 1, 1))) {
return new Version(0, 1, 11);
}
return version;
}
async getAvailableExtensionVersion(extension: DatabaseExtension): Promise<Version | null> {
const res = await this.dataSource.query(
`
SELECT version FROM pg_available_extension_versions
WHERE name = $1 AND installed = false
ORDER BY version DESC`,
[extension],
);
const version = res[0]?.['version'];
return version == null ? null : Version.fromString(version);
}
getPreferredVectorExtension(): VectorExtension {
return vectorExt;
}
async getPostgresVersion(): Promise<Version> {
const res = await this.dataSource.query(`SHOW server_version`);
return Version.fromString(res[0]['server_version']);
@ -25,6 +64,129 @@ export class DatabaseRepository implements IDatabaseRepository {
await this.dataSource.query(`CREATE EXTENSION IF NOT EXISTS ${extension}`);
}
async updateExtension(extension: DatabaseExtension, version?: Version): Promise<void> {
await this.dataSource.query(`ALTER EXTENSION ${extension} UPDATE${version ? ` TO '${version}'` : ''}`);
}
async updateVectorExtension(extension: VectorExtension, version?: Version): Promise<VectorUpdateResult> {
const curVersion = await this.getExtensionVersion(extension);
if (!curVersion) {
throw new Error(`${extName[extension]} extension is not installed`);
}
const minorOrMajor = version && curVersion.isOlderThan(version) >= VersionType.MINOR;
const isVectors = extension === DatabaseExtension.VECTORS;
let restartRequired = false;
await this.dataSource.manager.transaction(async (manager) => {
await this.setSearchPath(manager);
if (minorOrMajor && isVectors) {
await this.updateVectorsSchema(manager, curVersion);
}
await manager.query(`ALTER EXTENSION ${extension} UPDATE${version ? ` TO '${version}'` : ''}`);
if (!minorOrMajor) {
return;
}
if (isVectors) {
await manager.query('SELECT pgvectors_upgrade()');
restartRequired = true;
} else {
await this.reindex(VectorIndex.CLIP);
await this.reindex(VectorIndex.FACE);
}
});
return { restartRequired };
}
async reindex(index: VectorIndex): Promise<void> {
try {
await this.dataSource.query(`REINDEX INDEX ${index}`);
} catch (error) {
if (vectorExt === DatabaseExtension.VECTORS) {
this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`);
const table = index === VectorIndex.CLIP ? 'smart_search' : 'asset_faces';
const dimSize = await this.getDimSize(table);
await this.dataSource.manager.transaction(async (manager) => {
await this.setSearchPath(manager);
await manager.query(`DROP INDEX IF EXISTS ${index}`);
await manager.query(`ALTER TABLE ${table} ALTER COLUMN embedding SET DATA TYPE real[]`);
await manager.query(`ALTER TABLE ${table} ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`);
await manager.query(`SET vectors.pgvector_compatibility=on`);
await manager.query(`
CREATE INDEX IF NOT EXISTS ${index} ON ${table}
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
});
} else {
throw error;
}
}
}
async shouldReindex(name: VectorIndex): Promise<boolean> {
if (vectorExt !== DatabaseExtension.VECTORS) {
return false;
}
try {
const res = await this.dataSource.query(
`
SELECT idx_status
FROM pg_vector_index_stat
WHERE indexname = $1`,
[name],
);
return res[0]?.['idx_status'] === 'UPGRADE';
} catch (error) {
const message: string = (error as any).message;
if (message.includes('index is not existing')) {
return true;
} else if (message.includes('relation "pg_vector_index_stat" does not exist')) {
return false;
}
throw error;
}
}
private async setSearchPath(manager: EntityManager): Promise<void> {
await manager.query(`SET search_path TO "$user", public, vectors`);
}
private async updateVectorsSchema(manager: EntityManager, curVersion: Version): Promise<void> {
await manager.query('CREATE SCHEMA IF NOT EXISTS vectors');
await manager.query(`UPDATE pg_catalog.pg_extension SET extversion = $1 WHERE extname = $2`, [
curVersion.toString(),
DatabaseExtension.VECTORS,
]);
await manager.query('UPDATE pg_catalog.pg_extension SET extrelocatable = true WHERE extname = $1', [
DatabaseExtension.VECTORS,
]);
await manager.query('ALTER EXTENSION vectors SET SCHEMA vectors');
await manager.query('UPDATE pg_catalog.pg_extension SET extrelocatable = false WHERE extname = $1', [
DatabaseExtension.VECTORS,
]);
}
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
const res = await this.dataSource.query(`
SELECT atttypmod as dimsize
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid
WHERE c.relkind = 'r'::char
AND f.attnum > 0
AND c.relname = '${table}'
AND f.attname = '${column}'`);
const dimSize = res[0]['dimsize'];
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Could not retrieve dimension size`);
}
return dimSize;
}
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
await this.dataSource.runMigrations(options);
}

View file

@ -1,10 +1,18 @@
import { Embedding, EmbeddingSearch, FaceEmbeddingSearch, FaceSearchResult, ISmartInfoRepository } from '@app/domain';
import {
DatabaseExtension,
Embedding,
EmbeddingSearch,
FaceEmbeddingSearch,
FaceSearchResult,
ISmartInfoRepository,
} from '@app/domain';
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
import { ImmichLogger } from '@app/infra/logger';
import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { vectorExt } from '../database.config';
import { DummyValue, GenerateSql } from '../infra.util';
import { asVector, isValidInteger } from '../infra.utils';
@ -44,16 +52,20 @@ export class SmartInfoRepository implements ISmartInfoRepository {
params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }],
})
async searchCLIP({ userIds, embedding, numResults, withArchived }: EmbeddingSearch): Promise<AssetEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
// setting this too low messes with prefilter recall
numResults = Math.max(numResults, 64);
let results: AssetEntity[] = [];
await this.assetRepository.manager.transaction(async (manager) => {
await manager.query(`SET LOCAL vectors.enable_prefilter = on`);
let query = manager
const query = manager
.createQueryBuilder(AssetEntity, 'a')
.innerJoin('a.smartSearch', 's')
.leftJoinAndSelect('a.exifInfo', 'e')
.where('a.ownerId IN (:...userIds )')
.orderBy('s.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });
@ -61,15 +73,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
query.andWhere('a.isArchived = false');
}
query.andWhere('a.isVisible = true').andWhere('a.fileCreatedAt < NOW()');
query.limit(numResults);
if (numResults) {
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
query = query.limit(numResults);
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
}
await manager.query(this.getRuntimeConfig(numResults));
results = await query.getMany();
});
@ -93,36 +99,34 @@ export class SmartInfoRepository implements ISmartInfoRepository {
maxDistance,
hasPerson,
}: FaceEmbeddingSearch): Promise<FaceSearchResult[]> {
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
// setting this too low messes with prefilter recall
numResults = Math.max(numResults, 64);
let results: Array<AssetFaceEntity & { distance: number }> = [];
await this.assetRepository.manager.transaction(async (manager) => {
await manager.query(`SET LOCAL vectors.enable_prefilter = on`);
let cte = manager
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('1 + (faces.embedding <=> :embedding)', 'distance')
.select('faces.embedding <=> :embedding', 'distance')
.innerJoin('faces.asset', 'asset')
.where('asset.ownerId IN (:...userIds )')
.orderBy('1 + (faces.embedding <=> :embedding)')
.orderBy('faces.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });
if (numResults) {
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
cte = cte.limit(numResults);
if (numResults > 64) {
// setting k too low messes with prefilter recall
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
}
}
cte.limit(numResults);
if (hasPerson) {
cte = cte.andWhere('faces."personId" IS NOT NULL');
cte.andWhere('faces."personId" IS NOT NULL');
}
for (const col of this.faceColumns) {
cte.addSelect(`faces.${col}`, col);
}
await manager.query(this.getRuntimeConfig(numResults));
results = await manager
.createQueryBuilder()
.select('res.*')
@ -167,6 +171,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
await this.smartSearchRepository.manager.transaction(async (manager) => {
if (vectorExt === DatabaseExtension.VECTORS) {
await manager.query(`SET vectors.pgvector_compatibility=on`);
}
await manager.query(`DROP TABLE smart_search`);
await manager.query(`
@ -175,12 +182,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
embedding vector(${dimSize}) NOT NULL )`);
await manager.query(`
CREATE INDEX clip_index ON smart_search
USING vectors (embedding cosine_ops) WITH (options = $$
[indexing.hnsw]
m = 16
ef_construction = 300
$$)`);
CREATE INDEX IF NOT EXISTS clip_index ON smart_search
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
});
this.logger.log(`Successfully updated database CLIP dimension size from ${currentDimSize} to ${dimSize}.`);
@ -202,4 +206,17 @@ export class SmartInfoRepository implements ISmartInfoRepository {
}
return dimSize;
}
private getRuntimeConfig(numResults?: number): string {
if (vectorExt === DatabaseExtension.VECTOR) {
return 'SET LOCAL hnsw.ef_search = 1000;'; // mitigate post-filter recall
}
let runtimeConfig = 'SET LOCAL vectors.enable_prefilter=on; SET LOCAL vectors.search_mode=vbase;';
if (numResults) {
runtimeConfig += ` SET LOCAL vectors.hnsw_ef_search = ${numResults};`;
}
return runtimeConfig;
}
}

View file

@ -3,9 +3,13 @@
-- SmartInfoRepository.searchCLIP
START TRANSACTION
SET
LOCAL vectors.enable_prefilter = on
LOCAL vectors.enable_prefilter = on;
SET
LOCAL vectors.k = '100'
LOCAL vectors.search_mode = vbase;
SET
LOCAL vectors.hnsw_ef_search = 100;
SELECT
"a"."id" AS "a_id",
"a"."deviceAssetId" AS "a_deviceAssetId",
@ -85,9 +89,13 @@ COMMIT
-- SmartInfoRepository.searchFaces
START TRANSACTION
SET
LOCAL vectors.enable_prefilter = on
LOCAL vectors.enable_prefilter = on;
SET
LOCAL vectors.k = '100'
LOCAL vectors.search_mode = vbase;
SET
LOCAL vectors.hnsw_ef_search = 100;
WITH
"cte" AS (
SELECT
@ -100,7 +108,7 @@ WITH
"faces"."boundingBoxY1" AS "boundingBoxY1",
"faces"."boundingBoxX2" AS "boundingBoxX2",
"faces"."boundingBoxY2" AS "boundingBoxY2",
1 + ("faces"."embedding" <= > $1) AS "distance"
"faces"."embedding" <= > $1 AS "distance"
FROM
"asset_faces" "faces"
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
@ -108,7 +116,7 @@ WITH
WHERE
"asset"."ownerId" IN ($2)
ORDER BY
1 + ("faces"."embedding" <= > $1) ASC
"faces"."embedding" <= > $1 ASC
LIMIT
100
)

View file

@ -3,8 +3,14 @@ import { IDatabaseRepository, Version } from '@app/domain';
export const newDatabaseRepositoryMock = (): jest.Mocked<IDatabaseRepository> => {
return {
getExtensionVersion: jest.fn(),
getAvailableExtensionVersion: jest.fn(),
getPreferredVectorExtension: jest.fn(),
getPostgresVersion: jest.fn().mockResolvedValue(new Version(14, 0, 0)),
createExtension: jest.fn().mockImplementation(() => Promise.resolve()),
updateExtension: jest.fn(),
updateVectorExtension: jest.fn(),
reindex: jest.fn(),
shouldReindex: jest.fn(),
runMigrations: jest.fn(),
withLock: jest.fn().mockImplementation((_, function_: <R>() => Promise<R>) => function_()),
isBusy: jest.fn(),