-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Labels
Description
- 安裝 CUDA,選好版本複製命令到Python上安裝即可,1050Ti我安裝11.7
https://pytorch.org/get-started/locally/
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
作者源文提點
开始本教程前请先前往[pytorch](https://pytorch.org/get-started/locally/) 官网查看自己系统与硬件支持的pytorch版本
注意30系列之前的N卡,如2080Ti等请选择cuda11以下的版本(例:CUDA 10.2)
如果为30系N卡,仅支持CUDA 11版本,请选择CUDA 11以上版本(例:CUDA 11.3),
后根据选择的条件显示的pytorch安装命令完成pytorch安装,
由于pytorch的版本更新速度导致很多pypi源仅缓存了cpu版本,CUDA版本需要自己在官网安装。
- Clone dddd_trainer
git clone https://github.com/sml2h3/dddd_trainer.git- Pip Install
pip install -r .\requirements.txt
pip install onnx # 訓練完時如果沒有這個庫會報 no module named 'onnx'
pip install Pillow==9.5.0 # 為了解決 [AttributeError: module 'PIL.Image' has no attribute 'ANTIALIAS'](https://stackoverflow.com/questions/76616042/attributeerror-module-pil-image-has-no-attribute-antialias)- 新建訓練用專案
python app.py create {project_name}- 此時
projects會有專案名稱,在底下新建自己的訓練圖集資料夾 (以下使用images)
用於訓練的圖片量建議高於1200張,並且最好是 JPG 格式,訓練時間順利的話1~10分鐘,如果ACC一直卡住代表訓練集太少,假設卡 0.3 ACC,就再準備三倍以上的量再重新訓練較好
# 轉換PNG圖片至JPG
import os
from PIL import Image
import glob
for infile in glob.glob("images/*.png"):
file, ext = os.path.splitext(infile)
im = Image.open(infile)
rgb_im = im.convert('RGB')
rgb_im.save(file + ".jpg", "JPEG", quality=100)
os.remove(infile)注意:訓練圖如果是透明圖會訓練非常久,可以轉成白底
# 此時的目錄結構可能是
projects\{project_name}\images\
AAAXCT_0000.png
AABABN_0000.png-
根據需求修改 config.yaml,看不懂就先不用管
-
緩存訓練集,image_folder建議絕對路徑,否則可能會有各種奇怪問題
python app.py cache {project_name} images- 開始訓練
python app.py train {project_name}-
訓練完畢會生成在 {project_name}/{models},裡面會有兩個檔案都要丟到你實際使用ddddocr專案目錄
-
ddddocr 初始化代碼
ocr = ddddocr.DdddOcr(show_ad=False, det=False, ocr=False, import_onnx_path="model/oxxn_1.0_25_9000_2023-08-12-15-27-18.onnx",
charsets_path="model/charsets.json")- 完整測試代碼
import threading
import time
from os import listdir
import ddddocr
ocr = ddddocr.DdddOcr(show_ad=False, det=False, ocr=False, import_onnx_path="model/oxxn_1.0_25_9000_2023-08-12-15-27-18.onnx",
charsets_path="model/charsets.json")
image_src = r'..\projects\oxxn\images'
accuracy = 0
start_time = time.time()
def process(image_src):
global accuracy
with open(image_src, 'rb') as f:
image_bytes = f.read()
answer = image_src.split('\\')[-1].split('.')[0].split('_')[0]
res = ocr.classification(image_bytes)
print(f'{answer} OCR結果:{res} 結果:{answer == res}')
if answer == res:
accuracy += 1
threads = []
for i in listdir(image_src):
src = image_src + '\\' + i
threads.append(threading.Thread(target=process, args=(src,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
print(f'正確率:{accuracy / len(listdir(image_src)) * 100} %')
print(f'正確數:{accuracy} / {len(listdir(image_src))}')
print(f'花費時間:{time.time() - start_time}秒')Reactions are currently unavailable