From 92ca447f33bc5b319689e900abe153ec9316f1f1 Mon Sep 17 00:00:00 2001 From: Jason Rasmussen Date: Thu, 12 Jan 2023 21:15:45 -0500 Subject: [PATCH] refactor(server): use UserService (#1309) * refactor: communication gateway * refactor: share strategy * refactor: communication module --- .../communication/communication.gateway.ts | 58 ++++--------------- .../communication/communication.module.ts | 10 +--- .../communication/communication.service.ts | 4 -- .../immich/src/api-v1/share/share.core.ts | 24 +++----- .../immich/src/api-v1/share/share.service.ts | 43 ++++++++++++-- .../modules/immich-jwt/immich-jwt.module.ts | 4 +- .../immich-jwt/immich-jwt.service.spec.ts | 18 ++++-- .../modules/immich-jwt/immich-jwt.service.ts | 47 ++++++++++----- .../immich-jwt/strategies/jwt.strategy.ts | 17 ++---- .../strategies/public-share.strategy.ts | 36 +----------- 10 files changed, 113 insertions(+), 148 deletions(-) delete mode 100644 server/apps/immich/src/api-v1/communication/communication.service.ts diff --git a/server/apps/immich/src/api-v1/communication/communication.gateway.ts b/server/apps/immich/src/api-v1/communication/communication.gateway.ts index 52e40becd3..a0f5a99952 100644 --- a/server/apps/immich/src/api-v1/communication/communication.gateway.ts +++ b/server/apps/immich/src/api-v1/communication/communication.gateway.ts @@ -1,70 +1,32 @@ -import { OnGatewayConnection, OnGatewayDisconnect, WebSocketGateway, WebSocketServer } from '@nestjs/websockets'; -import { Socket, Server } from 'socket.io'; -import { ImmichJwtService, JwtValidationResult } from '../../modules/immich-jwt/immich-jwt.service'; import { Logger } from '@nestjs/common'; -import { InjectRepository } from '@nestjs/typeorm'; -import { UserEntity } from '@app/infra'; -import { Repository } from 'typeorm'; -import cookieParser from 'cookie'; -import { IMMICH_ACCESS_COOKIE } from '../../constants/jwt.constant'; +import { OnGatewayConnection, OnGatewayDisconnect, WebSocketGateway, WebSocketServer } from '@nestjs/websockets'; +import { Server, Socket } from 'socket.io'; +import { ImmichJwtService } from '../../modules/immich-jwt/immich-jwt.service'; @WebSocketGateway({ cors: true }) export class CommunicationGateway implements OnGatewayConnection, OnGatewayDisconnect { - constructor( - private immichJwtService: ImmichJwtService, + private logger = new Logger(CommunicationGateway.name); - @InjectRepository(UserEntity) - private userRepository: Repository, - ) {} + constructor(private immichJwtService: ImmichJwtService) {} @WebSocketServer() server!: Server; handleDisconnect(client: Socket) { client.leave(client.nsp.name); - - Logger.log(`Client ${client.id} disconnected from Websocket`, 'WebsocketConnectionEvent'); + this.logger.log(`Client ${client.id} disconnected from Websocket`); } async handleConnection(client: Socket) { try { - Logger.log(`New websocket connection: ${client.id}`, 'WebsocketConnectionEvent'); - let accessToken = ''; + this.logger.log(`New websocket connection: ${client.id}`); - if (client.handshake.headers.cookie != undefined) { - const cookies = cookieParser.parse(client.handshake.headers.cookie); - if (cookies[IMMICH_ACCESS_COOKIE]) { - accessToken = cookies[IMMICH_ACCESS_COOKIE]; - } else { - client.emit('error', 'unauthorized'); - client.disconnect(); - return; - } - } else if (client.handshake.headers.authorization != undefined) { - accessToken = client.handshake.headers.authorization.split(' ')[1]; + const user = await this.immichJwtService.validateSocket(client); + if (user) { + client.join(user.id); } else { client.emit('error', 'unauthorized'); client.disconnect(); - return; } - - const res: JwtValidationResult = accessToken - ? await this.immichJwtService.validateToken(accessToken) - : { status: false, userId: null }; - - if (!res.status || res.userId == null) { - client.emit('error', 'unauthorized'); - client.disconnect(); - return; - } - - const user = await this.userRepository.findOne({ where: { id: res.userId } }); - if (!user) { - client.emit('error', 'unauthorized'); - client.disconnect(); - return; - } - - client.join(user.id); } catch (e) { // Logger.error(`Error establish websocket conneciton ${e}`, 'HandleWebscoketConnection'); } diff --git a/server/apps/immich/src/api-v1/communication/communication.module.ts b/server/apps/immich/src/api-v1/communication/communication.module.ts index 73026dfb9b..1e3fd66618 100644 --- a/server/apps/immich/src/api-v1/communication/communication.module.ts +++ b/server/apps/immich/src/api-v1/communication/communication.module.ts @@ -1,16 +1,10 @@ import { Module } from '@nestjs/common'; -import { CommunicationService } from './communication.service'; import { CommunicationGateway } from './communication.gateway'; import { ImmichJwtModule } from '../../modules/immich-jwt/immich-jwt.module'; -import { ImmichJwtService } from '../../modules/immich-jwt/immich-jwt.service'; -import { JwtModule } from '@nestjs/jwt'; -import { jwtConfig } from '../../config/jwt.config'; -import { TypeOrmModule } from '@nestjs/typeorm'; -import { UserEntity } from '@app/infra'; @Module({ - imports: [TypeOrmModule.forFeature([UserEntity]), ImmichJwtModule, JwtModule.register(jwtConfig)], - providers: [CommunicationGateway, CommunicationService, ImmichJwtService], + imports: [ImmichJwtModule], + providers: [CommunicationGateway], exports: [CommunicationGateway], }) export class CommunicationModule {} diff --git a/server/apps/immich/src/api-v1/communication/communication.service.ts b/server/apps/immich/src/api-v1/communication/communication.service.ts deleted file mode 100644 index bca3987766..0000000000 --- a/server/apps/immich/src/api-v1/communication/communication.service.ts +++ /dev/null @@ -1,4 +0,0 @@ -import { Injectable } from '@nestjs/common'; - -@Injectable() -export class CommunicationService {} diff --git a/server/apps/immich/src/api-v1/share/share.core.ts b/server/apps/immich/src/api-v1/share/share.core.ts index aeadf5b7de..f5c1f6f182 100644 --- a/server/apps/immich/src/api-v1/share/share.core.ts +++ b/server/apps/immich/src/api-v1/share/share.core.ts @@ -32,7 +32,7 @@ export class ShareCore { } } - async getSharedLinks(userId: string): Promise { + getSharedLinks(userId: string): Promise { return this.sharedLinkRepository.get(userId); } @@ -46,27 +46,19 @@ export class ShareCore { return await this.sharedLinkRepository.remove(link); } - async getSharedLinkById(id: string): Promise { - const link = await this.sharedLinkRepository.getById(id); - if (!link) { - throw new BadRequestException('Shared link not found'); - } - - return link; + getSharedLinkById(id: string): Promise { + return this.sharedLinkRepository.getById(id); } - async getSharedLinkByKey(key: string): Promise { - const link = await this.sharedLinkRepository.getByKey(key); - - if (!link) { - throw new BadRequestException(); - } - - return link; + getSharedLinkByKey(key: string): Promise { + return this.sharedLinkRepository.getByKey(key); } async updateAssetsInSharedLink(sharedLinkId: string, assets: AssetEntity[]) { const link = await this.getSharedLinkById(sharedLinkId); + if (!link) { + throw new BadRequestException('Shared link not found'); + } link.assets = assets; diff --git a/server/apps/immich/src/api-v1/share/share.service.ts b/server/apps/immich/src/api-v1/share/share.service.ts index c3c9e63b80..2b35d5c061 100644 --- a/server/apps/immich/src/api-v1/share/share.service.ts +++ b/server/apps/immich/src/api-v1/share/share.service.ts @@ -1,4 +1,12 @@ -import { ForbiddenException, Inject, Injectable, Logger } from '@nestjs/common'; +import { + BadRequestException, + ForbiddenException, + Inject, + Injectable, + Logger, + UnauthorizedException, +} from '@nestjs/common'; +import { UserService } from '@app/domain'; import { AuthUserDto } from '../../decorators/auth-user.decorator'; import { EditSharedLinkDto } from './dto/edit-shared-link.dto'; import { mapSharedLinkToResponseDto, SharedLinkResponseDto } from './response-dto/shared-link-response.dto'; @@ -13,9 +21,31 @@ export class ShareService { constructor( @Inject(ISharedLinkRepository) sharedLinkRepository: ISharedLinkRepository, + private userService: UserService, ) { this.shareCore = new ShareCore(sharedLinkRepository); } + + async validate(key: string): Promise { + const link = await this.shareCore.getSharedLinkByKey(key); + if (link) { + if (!link.expiresAt || new Date(link.expiresAt) > new Date()) { + const user = await this.userService.getUserById(link.userId).catch(() => null); + if (user) { + return { + id: user.id, + email: user.email, + isAdmin: user.isAdmin, + isPublicUser: true, + sharedLinkId: link.id, + isAllowUpload: link.allowUpload, + }; + } + } + } + throw new UnauthorizedException(); + } + async getAll(authUser: AuthUserDto): Promise { const links = await this.shareCore.getSharedLinks(authUser.id); return links.map(mapSharedLinkToResponseDto); @@ -26,13 +56,14 @@ export class ShareService { throw new ForbiddenException(); } - const link = await this.shareCore.getSharedLinkById(authUser.sharedLinkId); - - return mapSharedLinkToResponseDto(link); + return this.getById(authUser.sharedLinkId); } async getById(id: string): Promise { const link = await this.shareCore.getSharedLinkById(id); + if (!link) { + throw new BadRequestException('Shared link not found'); + } return mapSharedLinkToResponseDto(link); } @@ -43,12 +74,14 @@ export class ShareService { async getByKey(key: string): Promise { const link = await this.shareCore.getSharedLinkByKey(key); + if (!link) { + throw new BadRequestException('Shared link not found'); + } return mapSharedLinkToResponseDto(link); } async edit(id: string, authUser: AuthUserDto, dto: EditSharedLinkDto) { const link = await this.shareCore.updateSharedLink(id, authUser.id, dto); - return mapSharedLinkToResponseDto(link); } } diff --git a/server/apps/immich/src/modules/immich-jwt/immich-jwt.module.ts b/server/apps/immich/src/modules/immich-jwt/immich-jwt.module.ts index 74cdfd15e9..94bd8b772c 100644 --- a/server/apps/immich/src/modules/immich-jwt/immich-jwt.module.ts +++ b/server/apps/immich/src/modules/immich-jwt/immich-jwt.module.ts @@ -3,15 +3,13 @@ import { ImmichJwtService } from './immich-jwt.service'; import { JwtModule } from '@nestjs/jwt'; import { jwtConfig } from '../../config/jwt.config'; import { JwtStrategy } from './strategies/jwt.strategy'; -import { TypeOrmModule } from '@nestjs/typeorm'; -import { UserEntity } from '@app/infra'; import { APIKeyModule } from '../../api-v1/api-key/api-key.module'; import { APIKeyStrategy } from './strategies/api-key.strategy'; import { ShareModule } from '../../api-v1/share/share.module'; import { PublicShareStrategy } from './strategies/public-share.strategy'; @Module({ - imports: [JwtModule.register(jwtConfig), TypeOrmModule.forFeature([UserEntity]), APIKeyModule, ShareModule], + imports: [JwtModule.register(jwtConfig), APIKeyModule, ShareModule], providers: [ImmichJwtService, JwtStrategy, APIKeyStrategy, PublicShareStrategy], exports: [ImmichJwtService], }) diff --git a/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts b/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts index 6ffe150247..41c549708a 100644 --- a/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts +++ b/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts @@ -5,9 +5,11 @@ import { UserEntity } from '@app/infra'; import { LoginResponseDto } from '../../api-v1/auth/response-dto/login-response.dto'; import { AuthType } from '../../constants/jwt.constant'; import { ImmichJwtService } from './immich-jwt.service'; +import { UserService } from '@app/domain'; describe('ImmichJwtService', () => { let jwtServiceMock: jest.Mocked; + let userServiceMock: jest.Mocked; let sut: ImmichJwtService; beforeEach(() => { @@ -16,7 +18,11 @@ describe('ImmichJwtService', () => { verifyAsync: jest.fn(), } as unknown as jest.Mocked; - sut = new ImmichJwtService(jwtServiceMock); + userServiceMock = { + getUserById: jest.fn(), + } as unknown as jest.Mocked; + + sut = new ImmichJwtService(jwtServiceMock, userServiceMock); }); afterEach(() => { @@ -102,7 +108,7 @@ describe('ImmichJwtService', () => { const request = { headers: {}, } as Request; - const token = sut.extractJwtFromHeader(request); + const token = sut.extractJwtFromHeader(request.headers); expect(token).toBe(null); }); @@ -119,15 +125,15 @@ describe('ImmichJwtService', () => { }, } as Request; - expect(sut.extractJwtFromHeader(upper)).toBe('token'); - expect(sut.extractJwtFromHeader(lower)).toBe('token'); + expect(sut.extractJwtFromHeader(upper.headers)).toBe('token'); + expect(sut.extractJwtFromHeader(lower.headers)).toBe('token'); }); }); describe('extracJwtFromCookie', () => { it('should handle no cookie', () => { const request = {} as Request; - const token = sut.extractJwtFromCookie(request); + const token = sut.extractJwtFromCookie(request.cookies); expect(token).toBe(null); }); @@ -137,7 +143,7 @@ describe('ImmichJwtService', () => { immich_access_token: 'cookie', }, } as Request; - const token = sut.extractJwtFromCookie(request); + const token = sut.extractJwtFromCookie(request.cookies); expect(token).toBe('cookie'); }); }); diff --git a/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.ts b/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.ts index 6765d24123..9b2f7e65c9 100644 --- a/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.ts +++ b/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.ts @@ -1,10 +1,13 @@ import { UserEntity } from '@app/infra'; import { Injectable, Logger } from '@nestjs/common'; import { JwtService } from '@nestjs/jwt'; -import { Request } from 'express'; +import { IncomingHttpHeaders } from 'http'; import { JwtPayloadDto } from '../../api-v1/auth/dto/jwt-payload.dto'; import { LoginResponseDto, mapLoginResponse } from '../../api-v1/auth/response-dto/login-response.dto'; import { AuthType, IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE, jwtSecret } from '../../constants/jwt.constant'; +import { Socket } from 'socket.io'; +import cookieParser from 'cookie'; +import { UserResponseDto, UserService } from '@app/domain'; export type JwtValidationResult = { status: boolean; @@ -13,7 +16,7 @@ export type JwtValidationResult = { @Injectable() export class ImmichJwtService { - constructor(private jwtService: JwtService) {} + constructor(private jwtService: JwtService, private userService: UserService) {} public getCookieNames() { return [IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE]; @@ -51,22 +54,40 @@ export class ImmichJwtService { } } - public extractJwtFromHeader(req: Request) { - if ( - req.headers.authorization && - (req.headers.authorization.split(' ')[0] === 'Bearer' || req.headers.authorization.split(' ')[0] === 'bearer') - ) { - const accessToken = req.headers.authorization.split(' ')[1]; - return accessToken; + public extractJwtFromHeader(headers: IncomingHttpHeaders) { + if (!headers.authorization) { + return null; + } + const [type, accessToken] = headers.authorization.split(' '); + if (type.toLowerCase() !== 'bearer') { + return null; + } + + return accessToken; + } + + public extractJwtFromCookie(cookies: Record) { + return cookies?.[IMMICH_ACCESS_COOKIE] || null; + } + + public async validateSocket(client: Socket): Promise { + const headers = client.handshake.headers; + const accessToken = + this.extractJwtFromCookie(cookieParser.parse(headers.cookie || '')) || this.extractJwtFromHeader(headers); + + if (accessToken) { + const { userId, status } = await this.validateToken(accessToken); + if (userId && status) { + const user = await this.userService.getUserById(userId).catch(() => null); + if (user) { + return user; + } + } } return null; } - public extractJwtFromCookie(req: Request) { - return req.cookies?.[IMMICH_ACCESS_COOKIE] || null; - } - private async generateToken(payload: JwtPayloadDto) { return this.jwtService.sign({ ...payload, diff --git a/server/apps/immich/src/modules/immich-jwt/strategies/jwt.strategy.ts b/server/apps/immich/src/modules/immich-jwt/strategies/jwt.strategy.ts index b11b79336c..4150664913 100644 --- a/server/apps/immich/src/modules/immich-jwt/strategies/jwt.strategy.ts +++ b/server/apps/immich/src/modules/immich-jwt/strategies/jwt.strategy.ts @@ -1,9 +1,7 @@ -import { UserEntity } from '@app/infra'; import { Injectable, UnauthorizedException } from '@nestjs/common'; import { PassportStrategy } from '@nestjs/passport'; -import { InjectRepository } from '@nestjs/typeorm'; import { ExtractJwt, Strategy, StrategyOptions } from 'passport-jwt'; -import { Repository } from 'typeorm'; +import { UserService } from '@app/domain'; import { JwtPayloadDto } from '../../../api-v1/auth/dto/jwt-payload.dto'; import { jwtSecret } from '../../../constants/jwt.constant'; import { AuthUserDto } from '../../../decorators/auth-user.decorator'; @@ -13,15 +11,11 @@ export const JWT_STRATEGY = 'jwt'; @Injectable() export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) { - constructor( - @InjectRepository(UserEntity) - private usersRepository: Repository, - immichJwtService: ImmichJwtService, - ) { + constructor(private userService: UserService, immichJwtService: ImmichJwtService) { super({ jwtFromRequest: ExtractJwt.fromExtractors([ - immichJwtService.extractJwtFromCookie, - immichJwtService.extractJwtFromHeader, + (req) => immichJwtService.extractJwtFromCookie(req.cookies), + (req) => immichJwtService.extractJwtFromHeader(req.headers), ]), ignoreExpiration: false, secretOrKey: jwtSecret, @@ -30,8 +24,7 @@ export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) { async validate(payload: JwtPayloadDto): Promise { const { userId } = payload; - const user = await this.usersRepository.findOne({ where: { id: userId } }); - + const user = await this.userService.getUserById(userId).catch(() => null); if (!user) { throw new UnauthorizedException('Failure to validate JWT payload'); } diff --git a/server/apps/immich/src/modules/immich-jwt/strategies/public-share.strategy.ts b/server/apps/immich/src/modules/immich-jwt/strategies/public-share.strategy.ts index c4953e518b..1c284c6da8 100644 --- a/server/apps/immich/src/modules/immich-jwt/strategies/public-share.strategy.ts +++ b/server/apps/immich/src/modules/immich-jwt/strategies/public-share.strategy.ts @@ -1,9 +1,6 @@ -import { UserEntity } from '@app/infra'; -import { Injectable, UnauthorizedException } from '@nestjs/common'; +import { Injectable } from '@nestjs/common'; import { PassportStrategy } from '@nestjs/passport'; -import { InjectRepository } from '@nestjs/typeorm'; import { IStrategyOptions, Strategy } from 'passport-http-header-strategy'; -import { Repository } from 'typeorm'; import { ShareService } from '../../../api-v1/share/share.service'; import { AuthUserDto } from '../../../decorators/auth-user.decorator'; @@ -16,38 +13,11 @@ const options: IStrategyOptions = { @Injectable() export class PublicShareStrategy extends PassportStrategy(Strategy, PUBLIC_SHARE_STRATEGY) { - constructor( - private shareService: ShareService, - @InjectRepository(UserEntity) - private usersRepository: Repository, - ) { + constructor(private shareService: ShareService) { super(options); } async validate(key: string): Promise { - const validatedLink = await this.shareService.getByKey(key); - - if (validatedLink.expiresAt) { - const now = new Date().getTime(); - const expiresAt = new Date(validatedLink.expiresAt).getTime(); - - if (now > expiresAt) { - throw new UnauthorizedException('Expired link'); - } - } - - const user = await this.usersRepository.findOne({ where: { id: validatedLink.userId } }); - - if (!user) { - throw new UnauthorizedException('Failure to validate public share payload'); - } - - let publicUser = new AuthUserDto(); - publicUser = user; - publicUser.isPublicUser = true; - publicUser.sharedLinkId = validatedLink.id; - publicUser.isAllowUpload = validatedLink.allowUpload; - - return publicUser; + return this.shareService.validate(key); } }