diff --git a/server/src/infra/repositories/smart-info.repository.ts b/server/src/infra/repositories/smart-info.repository.ts index fea1ef24d5..600c780ea4 100644 --- a/server/src/infra/repositories/smart-info.repository.ts +++ b/server/src/infra/repositories/smart-info.repository.ts @@ -11,13 +11,19 @@ import { asVector, isValidInteger } from '../infra.utils'; @Injectable() export class SmartInfoRepository implements ISmartInfoRepository { private logger = new Logger(SmartInfoRepository.name); + private faceColumns: string[]; constructor( @InjectRepository(SmartInfoEntity) private repository: Repository, @InjectRepository(AssetEntity) private assetRepository: Repository, @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository, @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository, - ) {} + ) { + this.faceColumns = this.assetFaceRepository.manager.connection + .getMetadata(AssetFaceEntity) + .ownColumns.map((column) => column.propertyName) + .filter((propertyName) => propertyName !== 'embedding'); + } async init(modelName: string): Promise { const { dimSize } = getCLIPModelInfo(modelName); @@ -79,13 +85,15 @@ export class SmartInfoRepository implements ISmartInfoRepository { await manager.query(`SET LOCAL vectors.k = '${numResults}'`); const cte = manager .createQueryBuilder(AssetFaceEntity, 'faces') - .addSelect('1 + (faces.embedding <=> :embedding)', 'distance') + .select('1 + (faces.embedding <=> :embedding)', 'distance') .innerJoin('faces.asset', 'asset') .where('asset.ownerId = :ownerId') .orderBy(`faces.embedding <=> :embedding`) .setParameters({ ownerId, embedding: asVector(embedding) }) .limit(numResults); + this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col)); + results = await manager .createQueryBuilder() .select('res.*') diff --git a/server/src/infra/sql/smart.info.repository.sql b/server/src/infra/sql/smart.info.repository.sql index 44de26ad92..a3931441f3 100644 --- a/server/src/infra/sql/smart.info.repository.sql +++ b/server/src/infra/sql/smart.info.repository.sql @@ -81,15 +81,15 @@ SET WITH "cte" AS ( SELECT - "faces"."id" AS "faces_id", - "faces"."assetId" AS "faces_assetId", - "faces"."personId" AS "faces_personId", - "faces"."imageWidth" AS "faces_imageWidth", - "faces"."imageHeight" AS "faces_imageHeight", - "faces"."boundingBoxX1" AS "faces_boundingBoxX1", - "faces"."boundingBoxY1" AS "faces_boundingBoxY1", - "faces"."boundingBoxX2" AS "faces_boundingBoxX2", - "faces"."boundingBoxY2" AS "faces_boundingBoxY2", + "faces"."id" AS "id", + "faces"."assetId" AS "assetId", + "faces"."personId" AS "personId", + "faces"."imageWidth" AS "imageWidth", + "faces"."imageHeight" AS "imageHeight", + "faces"."boundingBoxX1" AS "boundingBoxX1", + "faces"."boundingBoxY1" AS "boundingBoxY1", + "faces"."boundingBoxX2" AS "boundingBoxX2", + "faces"."boundingBoxY2" AS "boundingBoxY2", 1 + ("faces"."embedding" <= > $1) AS "distance" FROM "asset_faces" "faces"