1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-21 03:02:44 +01:00

refactor(server): use UserService (#1309)

* refactor: communication gateway

* refactor: share strategy

* refactor: communication module
This commit is contained in:
Jason Rasmussen 2023-01-12 21:15:45 -05:00 committed by GitHub
parent 755a1331da
commit 92ca447f33
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 113 additions and 148 deletions

View file

@ -1,70 +1,32 @@
import { OnGatewayConnection, OnGatewayDisconnect, WebSocketGateway, WebSocketServer } from '@nestjs/websockets';
import { Socket, Server } from 'socket.io';
import { ImmichJwtService, JwtValidationResult } from '../../modules/immich-jwt/immich-jwt.service';
import { Logger } from '@nestjs/common'; import { Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm'; import { OnGatewayConnection, OnGatewayDisconnect, WebSocketGateway, WebSocketServer } from '@nestjs/websockets';
import { UserEntity } from '@app/infra'; import { Server, Socket } from 'socket.io';
import { Repository } from 'typeorm'; import { ImmichJwtService } from '../../modules/immich-jwt/immich-jwt.service';
import cookieParser from 'cookie';
import { IMMICH_ACCESS_COOKIE } from '../../constants/jwt.constant';
@WebSocketGateway({ cors: true }) @WebSocketGateway({ cors: true })
export class CommunicationGateway implements OnGatewayConnection, OnGatewayDisconnect { export class CommunicationGateway implements OnGatewayConnection, OnGatewayDisconnect {
constructor( private logger = new Logger(CommunicationGateway.name);
private immichJwtService: ImmichJwtService,
@InjectRepository(UserEntity) constructor(private immichJwtService: ImmichJwtService) {}
private userRepository: Repository<UserEntity>,
) {}
@WebSocketServer() server!: Server; @WebSocketServer() server!: Server;
handleDisconnect(client: Socket) { handleDisconnect(client: Socket) {
client.leave(client.nsp.name); client.leave(client.nsp.name);
this.logger.log(`Client ${client.id} disconnected from Websocket`);
Logger.log(`Client ${client.id} disconnected from Websocket`, 'WebsocketConnectionEvent');
} }
async handleConnection(client: Socket) { async handleConnection(client: Socket) {
try { try {
Logger.log(`New websocket connection: ${client.id}`, 'WebsocketConnectionEvent'); this.logger.log(`New websocket connection: ${client.id}`);
let accessToken = '';
if (client.handshake.headers.cookie != undefined) { const user = await this.immichJwtService.validateSocket(client);
const cookies = cookieParser.parse(client.handshake.headers.cookie); if (user) {
if (cookies[IMMICH_ACCESS_COOKIE]) { client.join(user.id);
accessToken = cookies[IMMICH_ACCESS_COOKIE];
} else {
client.emit('error', 'unauthorized');
client.disconnect();
return;
}
} else if (client.handshake.headers.authorization != undefined) {
accessToken = client.handshake.headers.authorization.split(' ')[1];
} else { } else {
client.emit('error', 'unauthorized'); client.emit('error', 'unauthorized');
client.disconnect(); client.disconnect();
return;
} }
const res: JwtValidationResult = accessToken
? await this.immichJwtService.validateToken(accessToken)
: { status: false, userId: null };
if (!res.status || res.userId == null) {
client.emit('error', 'unauthorized');
client.disconnect();
return;
}
const user = await this.userRepository.findOne({ where: { id: res.userId } });
if (!user) {
client.emit('error', 'unauthorized');
client.disconnect();
return;
}
client.join(user.id);
} catch (e) { } catch (e) {
// Logger.error(`Error establish websocket conneciton ${e}`, 'HandleWebscoketConnection'); // Logger.error(`Error establish websocket conneciton ${e}`, 'HandleWebscoketConnection');
} }

View file

@ -1,16 +1,10 @@
import { Module } from '@nestjs/common'; import { Module } from '@nestjs/common';
import { CommunicationService } from './communication.service';
import { CommunicationGateway } from './communication.gateway'; import { CommunicationGateway } from './communication.gateway';
import { ImmichJwtModule } from '../../modules/immich-jwt/immich-jwt.module'; import { ImmichJwtModule } from '../../modules/immich-jwt/immich-jwt.module';
import { ImmichJwtService } from '../../modules/immich-jwt/immich-jwt.service';
import { JwtModule } from '@nestjs/jwt';
import { jwtConfig } from '../../config/jwt.config';
import { TypeOrmModule } from '@nestjs/typeorm';
import { UserEntity } from '@app/infra';
@Module({ @Module({
imports: [TypeOrmModule.forFeature([UserEntity]), ImmichJwtModule, JwtModule.register(jwtConfig)], imports: [ImmichJwtModule],
providers: [CommunicationGateway, CommunicationService, ImmichJwtService], providers: [CommunicationGateway],
exports: [CommunicationGateway], exports: [CommunicationGateway],
}) })
export class CommunicationModule {} export class CommunicationModule {}

View file

@ -1,4 +0,0 @@
import { Injectable } from '@nestjs/common';
@Injectable()
export class CommunicationService {}

View file

@ -32,7 +32,7 @@ export class ShareCore {
} }
} }
async getSharedLinks(userId: string): Promise<SharedLinkEntity[]> { getSharedLinks(userId: string): Promise<SharedLinkEntity[]> {
return this.sharedLinkRepository.get(userId); return this.sharedLinkRepository.get(userId);
} }
@ -46,27 +46,19 @@ export class ShareCore {
return await this.sharedLinkRepository.remove(link); return await this.sharedLinkRepository.remove(link);
} }
async getSharedLinkById(id: string): Promise<SharedLinkEntity> { getSharedLinkById(id: string): Promise<SharedLinkEntity | null> {
const link = await this.sharedLinkRepository.getById(id); return this.sharedLinkRepository.getById(id);
if (!link) {
throw new BadRequestException('Shared link not found');
}
return link;
} }
async getSharedLinkByKey(key: string): Promise<SharedLinkEntity> { getSharedLinkByKey(key: string): Promise<SharedLinkEntity | null> {
const link = await this.sharedLinkRepository.getByKey(key); return this.sharedLinkRepository.getByKey(key);
if (!link) {
throw new BadRequestException();
}
return link;
} }
async updateAssetsInSharedLink(sharedLinkId: string, assets: AssetEntity[]) { async updateAssetsInSharedLink(sharedLinkId: string, assets: AssetEntity[]) {
const link = await this.getSharedLinkById(sharedLinkId); const link = await this.getSharedLinkById(sharedLinkId);
if (!link) {
throw new BadRequestException('Shared link not found');
}
link.assets = assets; link.assets = assets;

View file

@ -1,4 +1,12 @@
import { ForbiddenException, Inject, Injectable, Logger } from '@nestjs/common'; import {
BadRequestException,
ForbiddenException,
Inject,
Injectable,
Logger,
UnauthorizedException,
} from '@nestjs/common';
import { UserService } from '@app/domain';
import { AuthUserDto } from '../../decorators/auth-user.decorator'; import { AuthUserDto } from '../../decorators/auth-user.decorator';
import { EditSharedLinkDto } from './dto/edit-shared-link.dto'; import { EditSharedLinkDto } from './dto/edit-shared-link.dto';
import { mapSharedLinkToResponseDto, SharedLinkResponseDto } from './response-dto/shared-link-response.dto'; import { mapSharedLinkToResponseDto, SharedLinkResponseDto } from './response-dto/shared-link-response.dto';
@ -13,9 +21,31 @@ export class ShareService {
constructor( constructor(
@Inject(ISharedLinkRepository) @Inject(ISharedLinkRepository)
sharedLinkRepository: ISharedLinkRepository, sharedLinkRepository: ISharedLinkRepository,
private userService: UserService,
) { ) {
this.shareCore = new ShareCore(sharedLinkRepository); this.shareCore = new ShareCore(sharedLinkRepository);
} }
async validate(key: string): Promise<AuthUserDto> {
const link = await this.shareCore.getSharedLinkByKey(key);
if (link) {
if (!link.expiresAt || new Date(link.expiresAt) > new Date()) {
const user = await this.userService.getUserById(link.userId).catch(() => null);
if (user) {
return {
id: user.id,
email: user.email,
isAdmin: user.isAdmin,
isPublicUser: true,
sharedLinkId: link.id,
isAllowUpload: link.allowUpload,
};
}
}
}
throw new UnauthorizedException();
}
async getAll(authUser: AuthUserDto): Promise<SharedLinkResponseDto[]> { async getAll(authUser: AuthUserDto): Promise<SharedLinkResponseDto[]> {
const links = await this.shareCore.getSharedLinks(authUser.id); const links = await this.shareCore.getSharedLinks(authUser.id);
return links.map(mapSharedLinkToResponseDto); return links.map(mapSharedLinkToResponseDto);
@ -26,13 +56,14 @@ export class ShareService {
throw new ForbiddenException(); throw new ForbiddenException();
} }
const link = await this.shareCore.getSharedLinkById(authUser.sharedLinkId); return this.getById(authUser.sharedLinkId);
return mapSharedLinkToResponseDto(link);
} }
async getById(id: string): Promise<SharedLinkResponseDto> { async getById(id: string): Promise<SharedLinkResponseDto> {
const link = await this.shareCore.getSharedLinkById(id); const link = await this.shareCore.getSharedLinkById(id);
if (!link) {
throw new BadRequestException('Shared link not found');
}
return mapSharedLinkToResponseDto(link); return mapSharedLinkToResponseDto(link);
} }
@ -43,12 +74,14 @@ export class ShareService {
async getByKey(key: string): Promise<SharedLinkResponseDto> { async getByKey(key: string): Promise<SharedLinkResponseDto> {
const link = await this.shareCore.getSharedLinkByKey(key); const link = await this.shareCore.getSharedLinkByKey(key);
if (!link) {
throw new BadRequestException('Shared link not found');
}
return mapSharedLinkToResponseDto(link); return mapSharedLinkToResponseDto(link);
} }
async edit(id: string, authUser: AuthUserDto, dto: EditSharedLinkDto) { async edit(id: string, authUser: AuthUserDto, dto: EditSharedLinkDto) {
const link = await this.shareCore.updateSharedLink(id, authUser.id, dto); const link = await this.shareCore.updateSharedLink(id, authUser.id, dto);
return mapSharedLinkToResponseDto(link); return mapSharedLinkToResponseDto(link);
} }
} }

View file

@ -3,15 +3,13 @@ import { ImmichJwtService } from './immich-jwt.service';
import { JwtModule } from '@nestjs/jwt'; import { JwtModule } from '@nestjs/jwt';
import { jwtConfig } from '../../config/jwt.config'; import { jwtConfig } from '../../config/jwt.config';
import { JwtStrategy } from './strategies/jwt.strategy'; import { JwtStrategy } from './strategies/jwt.strategy';
import { TypeOrmModule } from '@nestjs/typeorm';
import { UserEntity } from '@app/infra';
import { APIKeyModule } from '../../api-v1/api-key/api-key.module'; import { APIKeyModule } from '../../api-v1/api-key/api-key.module';
import { APIKeyStrategy } from './strategies/api-key.strategy'; import { APIKeyStrategy } from './strategies/api-key.strategy';
import { ShareModule } from '../../api-v1/share/share.module'; import { ShareModule } from '../../api-v1/share/share.module';
import { PublicShareStrategy } from './strategies/public-share.strategy'; import { PublicShareStrategy } from './strategies/public-share.strategy';
@Module({ @Module({
imports: [JwtModule.register(jwtConfig), TypeOrmModule.forFeature([UserEntity]), APIKeyModule, ShareModule], imports: [JwtModule.register(jwtConfig), APIKeyModule, ShareModule],
providers: [ImmichJwtService, JwtStrategy, APIKeyStrategy, PublicShareStrategy], providers: [ImmichJwtService, JwtStrategy, APIKeyStrategy, PublicShareStrategy],
exports: [ImmichJwtService], exports: [ImmichJwtService],
}) })

View file

@ -5,9 +5,11 @@ import { UserEntity } from '@app/infra';
import { LoginResponseDto } from '../../api-v1/auth/response-dto/login-response.dto'; import { LoginResponseDto } from '../../api-v1/auth/response-dto/login-response.dto';
import { AuthType } from '../../constants/jwt.constant'; import { AuthType } from '../../constants/jwt.constant';
import { ImmichJwtService } from './immich-jwt.service'; import { ImmichJwtService } from './immich-jwt.service';
import { UserService } from '@app/domain';
describe('ImmichJwtService', () => { describe('ImmichJwtService', () => {
let jwtServiceMock: jest.Mocked<JwtService>; let jwtServiceMock: jest.Mocked<JwtService>;
let userServiceMock: jest.Mocked<UserService>;
let sut: ImmichJwtService; let sut: ImmichJwtService;
beforeEach(() => { beforeEach(() => {
@ -16,7 +18,11 @@ describe('ImmichJwtService', () => {
verifyAsync: jest.fn(), verifyAsync: jest.fn(),
} as unknown as jest.Mocked<JwtService>; } as unknown as jest.Mocked<JwtService>;
sut = new ImmichJwtService(jwtServiceMock); userServiceMock = {
getUserById: jest.fn(),
} as unknown as jest.Mocked<UserService>;
sut = new ImmichJwtService(jwtServiceMock, userServiceMock);
}); });
afterEach(() => { afterEach(() => {
@ -102,7 +108,7 @@ describe('ImmichJwtService', () => {
const request = { const request = {
headers: {}, headers: {},
} as Request; } as Request;
const token = sut.extractJwtFromHeader(request); const token = sut.extractJwtFromHeader(request.headers);
expect(token).toBe(null); expect(token).toBe(null);
}); });
@ -119,15 +125,15 @@ describe('ImmichJwtService', () => {
}, },
} as Request; } as Request;
expect(sut.extractJwtFromHeader(upper)).toBe('token'); expect(sut.extractJwtFromHeader(upper.headers)).toBe('token');
expect(sut.extractJwtFromHeader(lower)).toBe('token'); expect(sut.extractJwtFromHeader(lower.headers)).toBe('token');
}); });
}); });
describe('extracJwtFromCookie', () => { describe('extracJwtFromCookie', () => {
it('should handle no cookie', () => { it('should handle no cookie', () => {
const request = {} as Request; const request = {} as Request;
const token = sut.extractJwtFromCookie(request); const token = sut.extractJwtFromCookie(request.cookies);
expect(token).toBe(null); expect(token).toBe(null);
}); });
@ -137,7 +143,7 @@ describe('ImmichJwtService', () => {
immich_access_token: 'cookie', immich_access_token: 'cookie',
}, },
} as Request; } as Request;
const token = sut.extractJwtFromCookie(request); const token = sut.extractJwtFromCookie(request.cookies);
expect(token).toBe('cookie'); expect(token).toBe('cookie');
}); });
}); });

View file

@ -1,10 +1,13 @@
import { UserEntity } from '@app/infra'; import { UserEntity } from '@app/infra';
import { Injectable, Logger } from '@nestjs/common'; import { Injectable, Logger } from '@nestjs/common';
import { JwtService } from '@nestjs/jwt'; import { JwtService } from '@nestjs/jwt';
import { Request } from 'express'; import { IncomingHttpHeaders } from 'http';
import { JwtPayloadDto } from '../../api-v1/auth/dto/jwt-payload.dto'; import { JwtPayloadDto } from '../../api-v1/auth/dto/jwt-payload.dto';
import { LoginResponseDto, mapLoginResponse } from '../../api-v1/auth/response-dto/login-response.dto'; import { LoginResponseDto, mapLoginResponse } from '../../api-v1/auth/response-dto/login-response.dto';
import { AuthType, IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE, jwtSecret } from '../../constants/jwt.constant'; import { AuthType, IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE, jwtSecret } from '../../constants/jwt.constant';
import { Socket } from 'socket.io';
import cookieParser from 'cookie';
import { UserResponseDto, UserService } from '@app/domain';
export type JwtValidationResult = { export type JwtValidationResult = {
status: boolean; status: boolean;
@ -13,7 +16,7 @@ export type JwtValidationResult = {
@Injectable() @Injectable()
export class ImmichJwtService { export class ImmichJwtService {
constructor(private jwtService: JwtService) {} constructor(private jwtService: JwtService, private userService: UserService) {}
public getCookieNames() { public getCookieNames() {
return [IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE]; return [IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE];
@ -51,22 +54,40 @@ export class ImmichJwtService {
} }
} }
public extractJwtFromHeader(req: Request) { public extractJwtFromHeader(headers: IncomingHttpHeaders) {
if ( if (!headers.authorization) {
req.headers.authorization && return null;
(req.headers.authorization.split(' ')[0] === 'Bearer' || req.headers.authorization.split(' ')[0] === 'bearer') }
) { const [type, accessToken] = headers.authorization.split(' ');
const accessToken = req.headers.authorization.split(' ')[1]; if (type.toLowerCase() !== 'bearer') {
return accessToken; return null;
}
return accessToken;
}
public extractJwtFromCookie(cookies: Record<string, string>) {
return cookies?.[IMMICH_ACCESS_COOKIE] || null;
}
public async validateSocket(client: Socket): Promise<UserResponseDto | null> {
const headers = client.handshake.headers;
const accessToken =
this.extractJwtFromCookie(cookieParser.parse(headers.cookie || '')) || this.extractJwtFromHeader(headers);
if (accessToken) {
const { userId, status } = await this.validateToken(accessToken);
if (userId && status) {
const user = await this.userService.getUserById(userId).catch(() => null);
if (user) {
return user;
}
}
} }
return null; return null;
} }
public extractJwtFromCookie(req: Request) {
return req.cookies?.[IMMICH_ACCESS_COOKIE] || null;
}
private async generateToken(payload: JwtPayloadDto) { private async generateToken(payload: JwtPayloadDto) {
return this.jwtService.sign({ return this.jwtService.sign({
...payload, ...payload,

View file

@ -1,9 +1,7 @@
import { UserEntity } from '@app/infra';
import { Injectable, UnauthorizedException } from '@nestjs/common'; import { Injectable, UnauthorizedException } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport'; import { PassportStrategy } from '@nestjs/passport';
import { InjectRepository } from '@nestjs/typeorm';
import { ExtractJwt, Strategy, StrategyOptions } from 'passport-jwt'; import { ExtractJwt, Strategy, StrategyOptions } from 'passport-jwt';
import { Repository } from 'typeorm'; import { UserService } from '@app/domain';
import { JwtPayloadDto } from '../../../api-v1/auth/dto/jwt-payload.dto'; import { JwtPayloadDto } from '../../../api-v1/auth/dto/jwt-payload.dto';
import { jwtSecret } from '../../../constants/jwt.constant'; import { jwtSecret } from '../../../constants/jwt.constant';
import { AuthUserDto } from '../../../decorators/auth-user.decorator'; import { AuthUserDto } from '../../../decorators/auth-user.decorator';
@ -13,15 +11,11 @@ export const JWT_STRATEGY = 'jwt';
@Injectable() @Injectable()
export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) { export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) {
constructor( constructor(private userService: UserService, immichJwtService: ImmichJwtService) {
@InjectRepository(UserEntity)
private usersRepository: Repository<UserEntity>,
immichJwtService: ImmichJwtService,
) {
super({ super({
jwtFromRequest: ExtractJwt.fromExtractors([ jwtFromRequest: ExtractJwt.fromExtractors([
immichJwtService.extractJwtFromCookie, (req) => immichJwtService.extractJwtFromCookie(req.cookies),
immichJwtService.extractJwtFromHeader, (req) => immichJwtService.extractJwtFromHeader(req.headers),
]), ]),
ignoreExpiration: false, ignoreExpiration: false,
secretOrKey: jwtSecret, secretOrKey: jwtSecret,
@ -30,8 +24,7 @@ export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) {
async validate(payload: JwtPayloadDto): Promise<AuthUserDto> { async validate(payload: JwtPayloadDto): Promise<AuthUserDto> {
const { userId } = payload; const { userId } = payload;
const user = await this.usersRepository.findOne({ where: { id: userId } }); const user = await this.userService.getUserById(userId).catch(() => null);
if (!user) { if (!user) {
throw new UnauthorizedException('Failure to validate JWT payload'); throw new UnauthorizedException('Failure to validate JWT payload');
} }

View file

@ -1,9 +1,6 @@
import { UserEntity } from '@app/infra'; import { Injectable } from '@nestjs/common';
import { Injectable, UnauthorizedException } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport'; import { PassportStrategy } from '@nestjs/passport';
import { InjectRepository } from '@nestjs/typeorm';
import { IStrategyOptions, Strategy } from 'passport-http-header-strategy'; import { IStrategyOptions, Strategy } from 'passport-http-header-strategy';
import { Repository } from 'typeorm';
import { ShareService } from '../../../api-v1/share/share.service'; import { ShareService } from '../../../api-v1/share/share.service';
import { AuthUserDto } from '../../../decorators/auth-user.decorator'; import { AuthUserDto } from '../../../decorators/auth-user.decorator';
@ -16,38 +13,11 @@ const options: IStrategyOptions = {
@Injectable() @Injectable()
export class PublicShareStrategy extends PassportStrategy(Strategy, PUBLIC_SHARE_STRATEGY) { export class PublicShareStrategy extends PassportStrategy(Strategy, PUBLIC_SHARE_STRATEGY) {
constructor( constructor(private shareService: ShareService) {
private shareService: ShareService,
@InjectRepository(UserEntity)
private usersRepository: Repository<UserEntity>,
) {
super(options); super(options);
} }
async validate(key: string): Promise<AuthUserDto> { async validate(key: string): Promise<AuthUserDto> {
const validatedLink = await this.shareService.getByKey(key); return this.shareService.validate(key);
if (validatedLink.expiresAt) {
const now = new Date().getTime();
const expiresAt = new Date(validatedLink.expiresAt).getTime();
if (now > expiresAt) {
throw new UnauthorizedException('Expired link');
}
}
const user = await this.usersRepository.findOne({ where: { id: validatedLink.userId } });
if (!user) {
throw new UnauthorizedException('Failure to validate public share payload');
}
let publicUser = new AuthUserDto();
publicUser = user;
publicUser.isPublicUser = true;
publicUser.sharedLinkId = validatedLink.id;
publicUser.isAllowUpload = validatedLink.allowUpload;
return publicUser;
} }
} }