mirror of
https://github.com/immich-app/immich.git
synced 2024-12-29 15:11:58 +00:00
feat(ml): introduce support of onnxruntime-rocm for AMD GPU
This commit is contained in:
parent
79a780e8d9
commit
46c505a592
14 changed files with 270 additions and 76 deletions
36
.github/workflows/docker.yml
vendored
36
.github/workflows/docker.yml
vendored
|
@ -48,21 +48,21 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
suffix: ["", "-cuda", "-openvino", "-armnn"]
|
||||
suffix: ['', '-cuda', '-openvino', '-armnn']
|
||||
steps:
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Re-tag image
|
||||
run: |
|
||||
REGISTRY_NAME="ghcr.io"
|
||||
REPOSITORY=${{ github.repository_owner }}/immich-machine-learning
|
||||
TAG_OLD=main${{ matrix.suffix }}
|
||||
TAG_NEW=${{ github.event.number == 0 && github.ref_name || format('pr-{0}', github.event.number) }}${{ matrix.suffix }}
|
||||
docker buildx imagetools create -t $REGISTRY_NAME/$REPOSITORY:$TAG_NEW $REGISTRY_NAME/$REPOSITORY:$TAG_OLD
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Re-tag image
|
||||
run: |
|
||||
REGISTRY_NAME="ghcr.io"
|
||||
REPOSITORY=${{ github.repository_owner }}/immich-machine-learning
|
||||
TAG_OLD=main${{ matrix.suffix }}
|
||||
TAG_NEW=${{ github.event.number == 0 && github.ref_name || format('pr-{0}', github.event.number) }}${{ matrix.suffix }}
|
||||
docker buildx imagetools create -t $REGISTRY_NAME/$REPOSITORY:$TAG_NEW $REGISTRY_NAME/$REPOSITORY:$TAG_OLD
|
||||
|
||||
retag_server:
|
||||
name: Re-Tag Server
|
||||
|
@ -71,7 +71,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
suffix: [""]
|
||||
suffix: ['']
|
||||
steps:
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
|
@ -87,7 +87,6 @@ jobs:
|
|||
TAG_NEW=${{ github.event.number == 0 && github.ref_name || format('pr-{0}', github.event.number) }}${{ matrix.suffix }}
|
||||
docker buildx imagetools create -t $REGISTRY_NAME/$REPOSITORY:$TAG_NEW $REGISTRY_NAME/$REPOSITORY:$TAG_OLD
|
||||
|
||||
|
||||
build_and_push_ml:
|
||||
name: Build and Push ML
|
||||
needs: pre-job
|
||||
|
@ -109,6 +108,10 @@ jobs:
|
|||
device: cuda
|
||||
suffix: -cuda
|
||||
|
||||
- platforms: linux/amd64
|
||||
device: rocm
|
||||
suffix: -rocm
|
||||
|
||||
- platforms: linux/amd64
|
||||
device: openvino
|
||||
suffix: -openvino
|
||||
|
@ -192,7 +195,6 @@ jobs:
|
|||
BUILD_SOURCE_REF=${{ github.ref_name }}
|
||||
BUILD_SOURCE_COMMIT=${{ github.sha }}
|
||||
|
||||
|
||||
build_and_push_server:
|
||||
name: Build and Push Server
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
@ -85,12 +85,12 @@ services:
|
|||
image: immich-machine-learning-dev:latest
|
||||
# extends:
|
||||
# file: hwaccel.ml.yml
|
||||
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
|
||||
# service: cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference
|
||||
build:
|
||||
context: ../machine-learning
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
|
||||
- DEVICE=cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference
|
||||
ports:
|
||||
- 3003:3003
|
||||
volumes:
|
||||
|
|
|
@ -29,12 +29,12 @@ services:
|
|||
image: immich-machine-learning:latest
|
||||
# extends:
|
||||
# file: hwaccel.ml.yml
|
||||
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
|
||||
# service: cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference
|
||||
build:
|
||||
context: ../machine-learning
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
|
||||
- DEVICE=cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference
|
||||
ports:
|
||||
- 3003:3003
|
||||
volumes:
|
||||
|
|
|
@ -32,12 +32,12 @@ services:
|
|||
|
||||
immich-machine-learning:
|
||||
container_name: immich_machine_learning
|
||||
# For hardware acceleration, add one of -[armnn, cuda, openvino] to the image tag.
|
||||
# For hardware acceleration, add one of -[armnn, cuda, rocm, openvino] to the image tag.
|
||||
# Example tag: ${IMMICH_VERSION:-release}-cuda
|
||||
image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release}
|
||||
# extends: # uncomment this section for hardware acceleration - see https://immich.app/docs/features/ml-hardware-acceleration
|
||||
# file: hwaccel.ml.yml
|
||||
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable
|
||||
# service: cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable
|
||||
volumes:
|
||||
- model-cache:/cache
|
||||
env_file:
|
||||
|
|
|
@ -26,6 +26,13 @@ services:
|
|||
capabilities:
|
||||
- gpu
|
||||
|
||||
rocm:
|
||||
group_add:
|
||||
- video
|
||||
devices:
|
||||
- /dev/dri:/dev/dri
|
||||
- /dev/kfd:/dev/kfd
|
||||
|
||||
openvino:
|
||||
device_cgroup_rules:
|
||||
- 'c 189:* rmw'
|
||||
|
|
|
@ -11,6 +11,7 @@ You do not need to redo any machine learning jobs after enabling hardware accele
|
|||
|
||||
- ARM NN (Mali)
|
||||
- CUDA (NVIDIA GPUs with [compute capability](https://developer.nvidia.com/cuda-gpus) 5.2 or higher)
|
||||
- ROCM (AMD GPUs)
|
||||
- OpenVINO (Intel discrete GPUs such as Iris Xe and Arc)
|
||||
|
||||
## Limitations
|
||||
|
@ -41,6 +42,10 @@ You do not need to redo any machine learning jobs after enabling hardware accele
|
|||
- The installed driver must be >= 535 (it must support CUDA 12.2).
|
||||
- On Linux (except for WSL2), you also need to have [NVIDIA Container Toolkit][nvct] installed.
|
||||
|
||||
#### ROCM
|
||||
|
||||
- The GPU must be supported by ROCM (or use `HSA_OVERRIDE_GFX_VERSION=<a supported version, ie 10.3.0>`)
|
||||
|
||||
#### OpenVINO
|
||||
|
||||
- The server must have a discrete GPU, i.e. Iris Xe or Arc. Expect issues when attempting to use integrated graphics.
|
||||
|
@ -50,12 +55,12 @@ You do not need to redo any machine learning jobs after enabling hardware accele
|
|||
|
||||
1. If you do not already have it, download the latest [`hwaccel.ml.yml`][hw-file] file and ensure it's in the same folder as the `docker-compose.yml`.
|
||||
2. In the `docker-compose.yml` under `immich-machine-learning`, uncomment the `extends` section and change `cpu` to the appropriate backend.
|
||||
3. Still in `immich-machine-learning`, add one of -[armnn, cuda, openvino] to the `image` section's tag at the end of the line.
|
||||
3. Still in `immich-machine-learning`, add one of -[armnn, cuda, rocm, openvino] to the `image` section's tag at the end of the line.
|
||||
4. Redeploy the `immich-machine-learning` container with these updated settings.
|
||||
|
||||
### Confirming Device Usage
|
||||
|
||||
You can confirm the device is being recognized and used by checking its utilization. There are many tools to display this, such as `nvtop` for NVIDIA or Intel and `intel_gpu_top` for Intel.
|
||||
You can confirm the device is being recognized and used by checking its utilization. There are many tools to display this, such as `nvtop` for NVIDIA or Intel, `intel_gpu_top` for Intel, and `radeontop` for AMD.
|
||||
|
||||
You can also check the logs of the `immich-machine-learning` container. When a Smart Search or Face Detection job begins, or when you search with text in Immich, you should either see a log for `Available ORT providers` containing the relevant provider (e.g. `CUDAExecutionProvider` in the case of CUDA), or a `Loaded ANN model` log entry without errors in the case of ARM NN.
|
||||
|
||||
|
|
|
@ -23,12 +23,12 @@ name: immich_remote_ml
|
|||
services:
|
||||
immich-machine-learning:
|
||||
container_name: immich_machine_learning
|
||||
# For hardware acceleration, add one of -[armnn, cuda, openvino] to the image tag.
|
||||
# For hardware acceleration, add one of -[armnn, cuda, rocm, openvino] to the image tag.
|
||||
# Example tag: ${IMMICH_VERSION:-release}-cuda
|
||||
image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release}
|
||||
# extends:
|
||||
# file: hwaccel.ml.yml
|
||||
# service: # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable
|
||||
# service: # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable
|
||||
volumes:
|
||||
- model-cache:/cache
|
||||
restart: always
|
||||
|
|
|
@ -15,6 +15,40 @@ RUN mkdir /opt/armnn && \
|
|||
cd /opt/ann && \
|
||||
sh build.sh
|
||||
|
||||
# Warning: 26.3Gb of disk space required to pull this image
|
||||
# https://github.com/microsoft/onnxruntime/blob/main/dockerfiles/Dockerfile.rocm
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.2-complete as builder-rocm
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends wget git python3.10-venv
|
||||
# Install same version as the Dockerfile provided by onnxruntime
|
||||
RUN wget https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.sh && \
|
||||
chmod +x cmake-3.27.3-linux-x86_64.sh && \
|
||||
mkdir -p /code/cmake-3.27.3-linux-x86_64 && \
|
||||
./cmake-3.27.3-linux-x86_64.sh --skip-license --prefix=/code/cmake-3.27.3-linux-x86_64 && \
|
||||
rm cmake-3.27.3-linux-x86_64.sh
|
||||
|
||||
ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH}
|
||||
|
||||
# Prepare onnxruntime repository & build onnxruntime
|
||||
RUN git clone --single-branch --branch v1.18.1 --recursive "https://github.com/Microsoft/onnxruntime" onnxruntime
|
||||
WORKDIR /code/onnxruntime
|
||||
# EDIT PR
|
||||
# While there's still this PR open, we need to compile on the branch of the PR
|
||||
# https://github.com/microsoft/onnxruntime/pull/19567
|
||||
COPY ./rocm-PR19567.patch /tmp/
|
||||
RUN git apply /tmp/rocm-PR19567.patch
|
||||
# END EDIT PR
|
||||
RUN /bin/sh ./dockerfiles/scripts/install_common_deps.sh
|
||||
# I ran into a compilation error when parallelizing the build
|
||||
# I used 12 threads to build onnxruntime, but it needs more than 16GB of RAM, and that's the amount of RAM I have on my machine
|
||||
# I lowered the number of threads to 8, and it worked
|
||||
# Even with 12 threads, the compilation took more than 1,5 hours to fail
|
||||
RUN ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel 9 --cmake_extra_defines\
|
||||
ONNXRUNTIME_VERSION=1.18.1 --use_rocm --rocm_home=/opt/rocm
|
||||
RUN mv /code/onnxruntime/build/Linux/Release/dist/*.whl /opt/
|
||||
|
||||
FROM builder-${DEVICE} AS builder
|
||||
|
||||
ARG DEVICE
|
||||
|
@ -32,6 +66,9 @@ RUN poetry config installer.max-workers 10 && \
|
|||
RUN python3 -m venv /opt/venv
|
||||
|
||||
COPY poetry.lock pyproject.toml ./
|
||||
RUN if [ "$DEVICE" = "rocm" ]; then \
|
||||
poetry add /opt/onnxruntime_rocm-*.whl; \
|
||||
fi
|
||||
RUN poetry install --sync --no-interaction --no-ansi --no-root --with ${DEVICE} --without dev
|
||||
|
||||
FROM python:3.11-slim-bookworm@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78 AS prod-cpu
|
||||
|
@ -80,11 +117,15 @@ COPY --from=builder-armnn \
|
|||
/opt/ann/build.sh \
|
||||
/opt/armnn/
|
||||
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.2-complete AS prod-rocm
|
||||
|
||||
|
||||
FROM prod-${DEVICE} AS prod
|
||||
|
||||
ARG DEVICE
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends tini $(if ! [ "$DEVICE" = "openvino" ]; then echo "libmimalloc2.0"; fi) && \
|
||||
apt-get install -y --no-install-recommends tini $(if ! [ "$DEVICE" = "openvino" ] && ! [ "$DEVICE" = "rocm" ]; then echo "libmimalloc2.0"; fi) && \
|
||||
apt-get autoremove -yqq && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
This project uses [Poetry](https://python-poetry.org/docs/#installation), so be sure to install it first.
|
||||
Running `poetry install --no-root --with dev --with cpu` will install everything you need in an isolated virtual environment.
|
||||
CUDA and OpenVINO are supported as acceleration APIs. To use them, you can replace `--with cpu` with either of `--with cuda` or `--with openvino`. In the case of CUDA, a [compute capability](https://developer.nvidia.com/cuda-gpus) of 5.2 or higher is required.
|
||||
CUDA, ROCM and OpenVINO are supported as acceleration APIs. To use them, you can replace `--with cpu` with either of `--with cuda`, `--with rocm` or `--with openvino`. In the case of CUDA, a [compute capability](https://developer.nvidia.com/cuda-gpus) of 5.2 or higher is required.
|
||||
|
||||
To add or remove dependencies, you can use the commands `poetry add $PACKAGE_NAME` and `poetry remove $PACKAGE_NAME`, respectively.
|
||||
Be sure to commit the `poetry.lock` and `pyproject.toml` files with `poetry lock --no-update` to reflect any changes in dependencies.
|
||||
|
@ -37,4 +37,4 @@ This project utilizes facial recognition models from the [InsightFace](https://g
|
|||
## License and Use Restrictions
|
||||
We have received permission to use the InsightFace facial recognition models in our project, as granted via email by Jia Guo (guojia@insightface.ai) on 18th March 2023. However, it's important to note that this permission does not extend to the redistribution or commercial use of their models by third parties. Users and developers interested in using these models should review the licensing terms provided in the InsightFace GitHub repository.
|
||||
|
||||
For more information on the capabilities of the InsightFace models and to ensure compliance with their license, please refer to their [official repository](https://github.com/deepinsight/insightface). Adhering to the specified licensing terms is crucial for the respectful and lawful use of their work.
|
||||
For more information on the capabilities of the InsightFace models and to ensure compliance with their license, please refer to their [official repository](https://github.com/deepinsight/insightface). Adhering to the specified licensing terms is crucial for the respectful and lawful use of their work.
|
||||
|
|
|
@ -63,7 +63,7 @@ _INSIGHTFACE_MODELS = {
|
|||
}
|
||||
|
||||
|
||||
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
||||
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "ROCMExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
||||
|
||||
|
||||
def get_model_source(model_name: str) -> ModelSource | None:
|
||||
|
|
|
@ -88,7 +88,7 @@ class OrtSession:
|
|||
match provider:
|
||||
case "CPUExecutionProvider":
|
||||
options = {"arena_extend_strategy": "kSameAsRequested"}
|
||||
case "CUDAExecutionProvider":
|
||||
case "CUDAExecutionProvider" | "ROCMExecutionProvider":
|
||||
options = {"arena_extend_strategy": "kSameAsRequested", "device_id": settings.device_id}
|
||||
case "OpenVINOExecutionProvider":
|
||||
options = {
|
||||
|
|
46
machine-learning/poetry.lock
generated
46
machine-learning/poetry.lock
generated
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiocache"
|
||||
|
@ -147,10 +147,6 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"},
|
||||
|
@ -163,14 +159,8 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"},
|
||||
|
@ -181,24 +171,8 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"},
|
||||
|
@ -208,10 +182,6 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"},
|
||||
|
@ -223,10 +193,6 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"},
|
||||
|
@ -239,10 +205,6 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"},
|
||||
|
@ -255,10 +217,6 @@ files = [
|
|||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"},
|
||||
{file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"},
|
||||
|
@ -3731,4 +3689,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "b690d5fbd141da3947f4f1dc029aba1b95e7faafd723166f2c4bdc47a66c095e"
|
||||
content-hash = "271a6c2a76b1b6286e02b91489ffd0c42e92daf151ae932514f5416c7869f71d"
|
||||
|
|
|
@ -47,6 +47,11 @@ optional = true
|
|||
[tool.poetry.group.cuda.dependencies]
|
||||
onnxruntime-gpu = {version = "^1.17.0", source = "cuda12"}
|
||||
|
||||
[tool.poetry.group.rocm]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.rocm.dependencies]
|
||||
|
||||
[tool.poetry.group.openvino]
|
||||
optional = true
|
||||
|
||||
|
|
176
machine-learning/rocm-PR19567.patch
Normal file
176
machine-learning/rocm-PR19567.patch
Normal file
|
@ -0,0 +1,176 @@
|
|||
From a598a88db258f82a6e4bca75810921bd6bcee7e0 Mon Sep 17 00:00:00 2001
|
||||
From: David Nieto <dmnieto@gmail.com>
|
||||
Date: Sat, 17 Feb 2024 11:23:12 -0800
|
||||
Subject: [PATCH] Disable algo caching in ROCM EP
|
||||
|
||||
Similar to the work done by Liangxijun-1001 in
|
||||
https://github.com/apache/tvm/pull/16178 the ROCM spec mandates calling
|
||||
miopenFindConvolution*Algorithm() before using any Convolution API
|
||||
|
||||
This is the link to the porting guide describing this requirement
|
||||
https://rocmdocs.amd.com/projects/MIOpen/en/latest/MIOpen_Porting_Guide.html
|
||||
|
||||
Thus, this change disables the algo cache and enforces the official
|
||||
API semantics
|
||||
|
||||
Signed-off-by: David Nieto <dmnieto@gmail.com>
|
||||
---
|
||||
onnxruntime/core/providers/rocm/nn/conv.cc | 61 +++++++++----------
|
||||
onnxruntime/core/providers/rocm/nn/conv.h | 6 --
|
||||
.../core/providers/rocm/nn/conv_transpose.cc | 17 +++---
|
||||
3 files changed, 36 insertions(+), 48 deletions(-)
|
||||
|
||||
diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc
|
||||
index 6214ec7bc0ea..b08aceca48b1 100644
|
||||
--- a/onnxruntime/core/providers/rocm/nn/conv.cc
|
||||
+++ b/onnxruntime/core/providers/rocm/nn/conv.cc
|
||||
@@ -125,10 +125,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
||||
if (input_dims_changed)
|
||||
s_.last_x_dims = gsl::make_span(x_dims);
|
||||
|
||||
- if (w_dims_changed) {
|
||||
+ if (w_dims_changed)
|
||||
s_.last_w_dims = gsl::make_span(w_dims);
|
||||
- s_.cached_benchmark_fwd_results.clear();
|
||||
- }
|
||||
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last));
|
||||
|
||||
@@ -277,35 +275,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
||||
HIP_CALL_THROW(hipMalloc(&s_.b_zero, malloc_size));
|
||||
HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context)));
|
||||
}
|
||||
-
|
||||
- if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) {
|
||||
- miopenConvAlgoPerf_t perf;
|
||||
- int algo_count = 1;
|
||||
- const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
|
||||
- static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
- size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
|
||||
- : AlgoSearchWorkspaceSize;
|
||||
- IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
|
||||
- MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
|
||||
- GetMiopenHandle(context),
|
||||
- s_.x_tensor,
|
||||
- s_.x_data,
|
||||
- s_.w_desc,
|
||||
- s_.w_data,
|
||||
- s_.conv_desc,
|
||||
- s_.y_tensor,
|
||||
- s_.y_data,
|
||||
- 1, // requestedAlgoCount
|
||||
- &algo_count, // returnedAlgoCount
|
||||
- &perf,
|
||||
- algo_search_workspace.get(),
|
||||
- max_ws_size,
|
||||
- false)); // Do not do exhaustive algo search.
|
||||
- s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory});
|
||||
- }
|
||||
- const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen);
|
||||
- s_.fwd_algo = perf.fwd_algo;
|
||||
- s_.workspace_bytes = perf.memory;
|
||||
} else {
|
||||
// set Y
|
||||
s_.Y = context->Output(0, TensorShape(s_.y_dims));
|
||||
@@ -319,6 +288,34 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
||||
s_.y_data = reinterpret_cast<HipT*>(s_.Y->MutableData<T>());
|
||||
}
|
||||
}
|
||||
+ {
|
||||
+ /* FindConvolution must always be called by the runtime */
|
||||
+ TensorShapeVector x_dims_miopen{x_dims.begin(), x_dims.end()};
|
||||
+ miopenConvAlgoPerf_t perf;
|
||||
+ int algo_count = 1;
|
||||
+ const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
|
||||
+ static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
+ size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
|
||||
+ : AlgoSearchWorkspaceSize;
|
||||
+ IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
|
||||
+ MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
|
||||
+ GetMiopenHandle(context),
|
||||
+ s_.x_tensor,
|
||||
+ s_.x_data,
|
||||
+ s_.w_desc,
|
||||
+ s_.w_data,
|
||||
+ s_.conv_desc,
|
||||
+ s_.y_tensor,
|
||||
+ s_.y_data,
|
||||
+ 1, // requestedAlgoCount
|
||||
+ &algo_count, // returnedAlgoCount
|
||||
+ &perf,
|
||||
+ algo_search_workspace.get(),
|
||||
+ max_ws_size,
|
||||
+ false)); // Do not do exhaustive algo search.
|
||||
+ s_.fwd_algo = perf.fwd_algo;
|
||||
+ s_.workspace_bytes = perf.memory;
|
||||
+ }
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h
|
||||
index bc9846203e57..d54218f25854 100644
|
||||
--- a/onnxruntime/core/providers/rocm/nn/conv.h
|
||||
+++ b/onnxruntime/core/providers/rocm/nn/conv.h
|
||||
@@ -108,9 +108,6 @@ class lru_unordered_map {
|
||||
list_type lru_list_;
|
||||
};
|
||||
|
||||
-// cached miopen descriptors
|
||||
-constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000;
|
||||
-
|
||||
template <typename AlgoPerfType>
|
||||
struct MiopenConvState {
|
||||
// if x/w dims changed, update algo and miopenTensors
|
||||
@@ -148,9 +145,6 @@ struct MiopenConvState {
|
||||
decltype(AlgoPerfType().memory) memory;
|
||||
};
|
||||
|
||||
- lru_unordered_map<TensorShapeVector, PerfFwdResultParams, vector_hash> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
|
||||
- lru_unordered_map<TensorShapeVector, PerfBwdResultParams, vector_hash> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
|
||||
-
|
||||
// Some properties needed to support asymmetric padded Conv nodes
|
||||
bool post_slicing_required;
|
||||
TensorShapeVector slice_starts;
|
||||
diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
|
||||
index 7447113fdf84..45ed4c8ac37a 100644
|
||||
--- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
|
||||
+++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
|
||||
@@ -76,7 +76,6 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
|
||||
|
||||
if (w_dims_changed) {
|
||||
s_.last_w_dims = gsl::make_span(w_dims);
|
||||
- s_.cached_benchmark_bwd_results.clear();
|
||||
}
|
||||
|
||||
ConvTransposeAttributes::Prepare p;
|
||||
@@ -127,12 +126,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
|
||||
|
||||
y_data = reinterpret_cast<HipT*>(p.Y->MutableData<T>());
|
||||
|
||||
- if (!s_.cached_benchmark_bwd_results.contains(x_dims)) {
|
||||
- IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
|
||||
-
|
||||
- miopenConvAlgoPerf_t perf;
|
||||
- int algo_count = 1;
|
||||
- MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
|
||||
+ }
|
||||
+ // The following is required before calling convolution, we cannot cache the results
|
||||
+ {
|
||||
+ IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
|
||||
+ miopenConvAlgoPerf_t perf;
|
||||
+ int algo_count = 1;
|
||||
+ MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
|
||||
GetMiopenHandle(context),
|
||||
s_.x_tensor,
|
||||
x_data,
|
||||
@@ -147,10 +147,7 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
|
||||
algo_search_workspace.get(),
|
||||
AlgoSearchWorkspaceSize,
|
||||
false));
|
||||
- s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory});
|
||||
- }
|
||||
|
||||
- const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims);
|
||||
s_.bwd_data_algo = perf.bwd_data_algo;
|
||||
s_.workspace_bytes = perf.memory;
|
||||
}
|
Loading…
Reference in a new issue