from pathlib import Path

import onnx
import onnxruntime as ort
import onnxsim


def save_onnx(model: onnx.ModelProto, output_path: Path | str) -> None:
    """Save ``model`` to ``output_path``, retrying with external data if too large.

    A plain ``onnx.save`` is attempted first. If it raises the ``ValueError``
    whose message reports the 2 GB proto size limit, the save is retried with
    ``save_as_external_data=True`` and ``size_threshold=1_000_000``.

    Args:
        model: The ONNX model to serialize.
        output_path: Destination path of the model file.

    Raises:
        ValueError: Any ``ValueError`` from ``onnx.save`` other than the
            2 GB size-limit error is propagated unchanged.
    """
    try:
        onnx.save(model, output_path)
    except ValueError as e:
        if "The proto size is larger than the 2 GB limit." not in str(e):
            # Unrelated failure: re-raise with the original traceback intact.
            raise
        onnx.save(model, output_path, save_as_external_data=True, size_threshold=1_000_000)


def optimize_onnxsim(model_path: Path | str, output_path: Path | str) -> None:
    """Simplify the model at ``model_path`` with onnxsim and save the result.

    After the simplified model passes onnxsim's check, entries in the
    directory containing ``model_path`` are removed when their name starts
    with ``"Constant"``, contains ``"onnx"``, or has the ``.weight`` suffix.
    The simplified model is then written to ``output_path`` via ``save_onnx``.

    NOTE(review): the cleanup matches *any* entry containing "onnx" in the
    model's directory (including the input model and unrelated models) --
    presumably each model lives in its own directory; confirm with callers.

    Args:
        model_path: Path of the ONNX model to load.
        output_path: Path the simplified model is written to.

    Raises:
        AssertionError: If onnxsim reports that the simplified model could
            not be validated. No files are deleted in that case.
    """
    model_path = Path(model_path)
    output_path = Path(output_path)

    model = onnx.load(model_path.as_posix())
    model, check = onnxsim.simplify(model)
    # Raised explicitly rather than via `assert` so the validation is not
    # stripped under `python -O`; the exception type is kept for callers.
    if not check:
        raise AssertionError("Simplified ONNX model could not be validated")

    for file in model_path.parent.iterdir():
        stale = file.name.startswith("Constant") or "onnx" in file.name or file.suffix == ".weight"
        if not stale:
            continue
        # unlink() cannot remove a real directory and would abort the cleanup
        # half-way; skip those (symlinks are still unlinked as before).
        if file.is_dir() and not file.is_symlink():
            continue
        file.unlink()

    save_onnx(model, output_path)


def optimize_ort(
    model_path: Path | str,
    output_path: Path | str,
    level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
) -> None:
    """Run ONNX Runtime graph optimization on a model and write the result.

    Args:
        model_path: Path of the ONNX model to load.
        output_path: Path set as the session's ``optimized_model_filepath``.
        level: Graph optimization level applied by ONNX Runtime.
    """
    model_path = Path(model_path)
    output_path = Path(output_path)

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = level
    sess_options.optimized_model_filepath = output_path.as_posix()

    # The session itself is discarded: it is created only so that ONNX Runtime
    # loads the model with the options above (which, per the ORT docs, makes it
    # serialize the optimized graph to `optimized_model_filepath`).
    ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"], sess_options=sess_options)


def optimize(model_path: Path | str) -> None:
    """Optimize the model at ``model_path`` in place.

    Runs ``optimize_ort`` (default optimization level) and then
    ``optimize_onnxsim``, each reading from and writing to ``model_path``.

    Args:
        model_path: Path of the ONNX model to optimize and overwrite.
    """
    model_path = Path(model_path)

    optimize_ort(model_path, model_path)
    optimize_onnxsim(model_path, model_path)