diff --git a/server/src/interfaces/search.interface.ts b/server/src/interfaces/search.interface.ts index 57523aa940..ce9e2a1940 100644 --- a/server/src/interfaces/search.interface.ts +++ b/server/src/interfaces/search.interface.ts @@ -155,8 +155,9 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions { export interface AssetDuplicateSearch { assetId: string; embedding: Embedding; - userIds: string[]; maxDistance?: number; + type: AssetType; + userIds: string[]; } export interface FaceSearchResult { diff --git a/server/src/queries/search.repository.sql b/server/src/queries/search.repository.sql index 1a4245592b..9efeae6248 100644 --- a/server/src/queries/search.repository.sql +++ b/server/src/queries/search.repository.sql @@ -204,6 +204,7 @@ WITH "asset"."ownerId" IN ($2) AND "asset"."id" != $3 AND "asset"."isVisible" = $4 + AND "asset"."type" = $5 ) AND ("asset"."deletedAt" IS NULL) ORDER BY @@ -216,7 +217,7 @@ SELECT FROM "cte" "res" WHERE - res.distance <= $5 + res.distance <= $6 -- SearchRepository.searchFaces START TRANSACTION diff --git a/server/src/repositories/search.repository.ts b/server/src/repositories/search.repository.ts index 072d452777..f0c5dcb364 100644 --- a/server/src/repositories/search.repository.ts +++ b/server/src/repositories/search.repository.ts @@ -160,6 +160,7 @@ export class SearchRepository implements ISearchRepository { assetId, embedding, maxDistance, + type, userIds, }: AssetDuplicateSearch): Promise { const cte = this.assetRepository.createQueryBuilder('asset'); @@ -171,18 +172,22 @@ export class SearchRepository implements ISearchRepository { .where('asset.ownerId IN (:...userIds )') .andWhere('asset.id != :assetId') .andWhere('asset.isVisible = :isVisible') + .andWhere('asset.type = :type') .orderBy('search.embedding <=> :embedding') .limit(64) - .setParameters({ assetId, embedding: asVector(embedding), isVisible: true, userIds }); + .setParameters({ assetId, embedding: asVector(embedding), isVisible: true, type, userIds }); const builder = this.assetRepository.manager .createQueryBuilder() .addCommonTableExpression(cte, 'cte') .from('cte', 'res') - .select('res.*') - .where('res.distance <= :maxDistance', { maxDistance }); + .select('res.*'); - return builder.getRawMany() as any as Promise; + if (maxDistance) { + builder.where('res.distance <= :maxDistance', { maxDistance }); + } + + return builder.getRawMany() as Promise; } @GenerateSql({ diff --git a/server/src/services/duplicate.service.spec.ts b/server/src/services/duplicate.service.spec.ts index 4560d9024c..79374ea7ae 100644 --- a/server/src/services/duplicate.service.spec.ts +++ b/server/src/services/duplicate.service.spec.ts @@ -215,6 +215,7 @@ describe(SearchService.name, () => { assetId: assetStub.hasEmbedding.id, embedding: assetStub.hasEmbedding.smartSearch!.embedding, maxDistance: 0.03, + type: assetStub.hasEmbedding.type, userIds: [assetStub.hasEmbedding.ownerId], }); expect(assetMock.updateDuplicates).toHaveBeenCalledWith({ @@ -240,6 +241,7 @@ describe(SearchService.name, () => { assetId: assetStub.hasEmbedding.id, embedding: assetStub.hasEmbedding.smartSearch!.embedding, maxDistance: 0.03, + type: assetStub.hasEmbedding.type, userIds: [assetStub.hasEmbedding.ownerId], }); expect(assetMock.updateDuplicates).toHaveBeenCalledWith({ diff --git a/server/src/services/duplicate.service.ts b/server/src/services/duplicate.service.ts index 95a12bd18e..6313ffa21f 100644 --- a/server/src/services/duplicate.service.ts +++ b/server/src/services/duplicate.service.ts @@ -94,6 +94,7 @@ export class DuplicateService { assetId: asset.id, embedding: asset.smartSearch.embedding, maxDistance: machineLearning.duplicateDetection.maxDistance, + type: asset.type, userIds: [asset.ownerId], });