Всем привет! Чтобы классифицировать конструкции коробок, которые могли бы принести пользу моей упаковочной компании, я недавно создал приложение для машинного обучения, используя платформу Jetson. Цель состояла в том, чтобы автоматизировать такие задачи, как печать ценников на коробках на основе дизайна и предотвращение ошибок, таких как отправка не тех коробок в больших количествах. В следующих разделах я кратко объясню настройку и код.

Для сборки приложения у нас должно быть следующее оборудование,

  • Jetson AGX Xavier (мы могли бы также использовать более простое устройство, такое как nano)
  • USB-камера (Logitech)
  • Мышь и клавиатура
  • Монитор
  • сетевой кабель

И последнее, но не менее важное: хост-система Linux для установки JetPack и загрузки Jetson AGX Xavier. NVIDIA также предоставляет возможность загрузки с SD-карты. Чтобы узнать больше, прочитайте следующую ссылку,



Подключите их и начните с установки JetPack через SDK Manager. Я установил его по официальной ссылке (https://developer.nvidia.com/embedded/jetpack) и имею следующие версии JetPack и Cuda,

  • Версия реактивного ранца: 5.1.1 [L4T 35.3.1]
  • Версия Cuda: инструменты компиляции Cuda, выпуск 11.4, V11.4.315.

После этого мы можем использовать образ докера от NVIDIA, чтобы запустить контейнер среды выполнения и приступить к практической работе.

Ссылка: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-base

Обратите внимание, что после того, как мы настроили наш Jetson для загрузки, мы можем разработать приложение двумя способами. Один через режим ssh, а другой через обычный ПК, например настройку. Я предпочел обычный подход, так как был более склонен к созданию приложения Python на основе Tkinter, которому требовалось устройство отображения для тестирования пользовательского интерфейса.

Структура папок проекта приведена ниже,

Мы будем использовать вспомогательный класс ImageClassificationDataset из helpers/dataset.py, который сохраняет изображения в соответствующие папки меток с желаемым именем набора данных.

class ImageClassificationDataset(torch.utils.data.Dataset):
    
    def __init__(self, directory, categories, transform=None):
        self.categories = categories
        self.directory = directory
        self.transform = transform
        self._refresh()
    
    
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        image = cv2.imread(ann['image_path'], cv2.IMREAD_COLOR)
        image = PIL.Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
        return image, ann['category_index']
    
    
    def _refresh(self):
        self.annotations = []
        for category in self.categories:
            category_index = self.categories.index(category)
            for image_path in glob.glob(os.path.join(self.directory, category, '*.jpg')):
                self.annotations += [{
                    'image_path': image_path,
                    'category_index': category_index,
                    'category': category
                }]
    
    def save_entry(self, image, category):
        """Saves an image in BGR8 format to dataset for category"""
        if category not in self.categories:
            raise KeyError('There is no category named %s in this dataset.' % category)
            
        filename = str(uuid.uuid1()) + '.jpg'
        category_directory = os.path.join(self.directory, category)
        
        if not os.path.exists(category_directory):
            subprocess.call(['mkdir', '-p', category_directory])
            
        image_path = os.path.join(category_directory, filename)
        cv2.imwrite(image_path, image)
        self._refresh()
        return image_path
    
    def get_count(self, category):
        i = 0
        for a in self.annotations:
            if a['category'] == category:
                i += 1
        return i

Как видите, конструктор принимает параметр каталога, соответствующий имени набора данных, а затем параметр категории, определяющий метки, которые должны быть размещены в наборе данных.

Разобравшись с помощником проекта и набора данных, давайте перейдем к сбору данных с последующим обучением и классификацией.

Сбор данных

Чтобы все модули использовали одну и ту же конфигурацию, я определил config.json следующим образом:

{
 "config": {
  "task": "diff",
  "datasets": [
   "2023q1",
   "2023q2"
  ],
  "labels": [
            "None",
   "Floral101",
   "Floral102"
        ]
 }
}

После этого я также создал функцию чтения внутри файла read.py,

from dataset import ImageClassificationDataset
import torchvision.transforms as transforms
import json

def read_dataset():
    f = open('config.json')
    json_data = json.load(f)
    TRANSFORMS = transforms.Compose([
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    datasets = {}
    for name in json_data["config"]["datasets"]:
        datasets[name] = ImageClassificationDataset(json_data["config"]["task"] + '_' + name,  json_data["config"]["labels"], TRANSFORMS)
    dataset = datasets[json_data["config"]["datasets"][0]]
    
    print(json_data)
    print(json_data["config"]["task"])
    print("{0} task with {1} categories defined".format(json_data["config"]["task"], json_data["config"]["labels"]))
    
    categories = json_data["config"]["labels"]
    
    return (datasets, dataset)

Эта функция чтения использует файл json для создания каталога набора данных как «diff_2023q1» и пометки папок как «Нет», «Цветочный101», «Цветочный102» соответственно.

Вызов функции read_dataset подготовит нас к созданию следующей структуры папок. По мере сбора данных с помощью нашего приложения для сбора данных мы будем больше понимать структуру набора данных.

Ниже приведен код из data_collection.py,

#append src folder to read dataset.py from it
import sys
sys.path.append('helpers')

import tkinter
from tkinter import *
from tkinter import ttk
import numpy as np
from PIL import Image, ImageTk
import cv2
from jetcam.usb_camera import USBCamera
from jetcam.utils import bgr8_to_jpeg
import os
from tkinter import messagebox

from dataset import ImageClassificationDataset
import torchvision.transforms as transforms

from reader import read_dataset

global cam
global currFrame
global datasets
global dataset
global countVar
global root

def prepare():
    global datasets
    global dataset

    datasets, dataset = read_dataset()
    
def get_label_selection():
    print("Dataset label: {}".format(value_inside.get()))
    return value_inside.get()

def update_image(change):
    global currFrame
    image = change['new']
    currFrame = image
    
def start():
    global frame
    global cam
    global currFrame
    global countVar
    
    option = get_label_selection()
    if option not in dataset.categories:
        messagebox.showerror('Error', 'Please select a label!')
        return
    
    countVar.set("Image count : {0}".format(dataset.get_count(option)))

    # cam = cv2.VideoCapture(0)
    #cv2.namedWindow("Experience_in_AI camera")
    while cam.running:
        frame = currFrame #cam.read()

        #Update the image to tkinter...
        frame=cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
        PIL_image = Image.fromarray(np.uint8(frame)).convert('RGB')
        img_update = ImageTk.PhotoImage(PIL_image)
        paneeli_image.configure(image=img_update)
        paneeli_image.image=img_update
        paneeli_image.update()

        if frame.shape[0] == 0:
            print("failed to grab frame")
            break

def stop():
    global cam
    global root
    
    cam.running = False
    cam.unobserve(update_image, names='value')
    root.quit()
    print("Stopped!")
    
def save(event=None):
    global dataset
    global cam
    
    option = get_label_selection()
        
    print("save called")
    dataset.save_entry(cam.value, get_label_selection())
    countVar.set("Image count : {0}".format(dataset.get_count(option)))
    
def on_select(event):
    print("on_select called")
    option = get_label_selection()
    countVar.set("Image count : {0}".format(dataset.get_count(option)))

# create_dataset_folder()

prepare()

cam = USBCamera(width=224, height=224, capture_width=640, capture_height=480, capture_device=0)
image = cam.read()

cam.running = True
cam.observe(update_image, names='value')
    
print(image.shape)

root=tkinter.Tk()
root.title("Dataset creator app")

countVar = StringVar()
countVar.set("Image count : {0}".format("-"))

frame=np.random.randint(0,255,[100,100,3],dtype='uint8')
img = ImageTk.PhotoImage(Image.fromarray(frame))

paneeli_image=tkinter.Label(root) #,image=img)
paneeli_image.grid(row=0,column=0,columnspan=3,pady=1,padx=10)
  
# Variable to keep track of the option
# selected in OptionMenu
value_inside = tkinter.StringVar(root)
  
# Set the default value of the variable
value_inside.set("Select a Label")
  
# Create the optionmenu widget and passing 
# the options_list and value_inside to it.
question_menu = tkinter.OptionMenu(root, value_inside, *dataset.categories, command = on_select)
question_menu.grid(row=1,column=1,pady=1,padx=10)

component_height=5
startButton=tkinter.Button(root,text="Start",command=start,height=5,width=20)
startButton.grid(row=1,column=0,pady=10,padx=10)
startButton.config(height=1*component_height,width=10)

component_height=5
stopButton=tkinter.Button(root,text="Exit",command=stop,height=5,width=20)
stopButton.grid(row=1,column=2,pady=10,padx=10)
stopButton.config(height=1*component_height,width=10)

label = Label(root, textvariable=countVar)
font=('Calibri 14 bold')
label.grid(row=2,column=1,pady=10,padx=10)

label = Label(root, text="Use spacebar to save a snapshot")
font=('Calibri 12 bold')
label.grid(row=3,column=1,pady=10,padx=10)

root.bind("<space>", save)

root.mainloop()
  • «prepare()» создает дескриптор для организации изображений под соответствующими метками.
  • «get_label_selection()» возвращает текущую выбранную опцию в виджете, которая является интересующей меткой.
  • Функция «update_image()» прикреплена в качестве функции наблюдателя к экземпляру камеры jetcam, которая вызывается при каждом кадре. Эта функция создает ссылку на текущий кадр для виджета изображения tkinter для отображения текущего кадра (для прямой трансляции).
  • Функция «start()» запускает цикл, который обновляет виджет изображения текущим кадром.
  • Функция «stop()» удаляет текущий наблюдатель кадра из экземпляра камеры и завершает работу приложения.
  • Функция «сохранить ()» сохраняет изображение в соответствующую папку меток в наборе данных и сопоставляется с «слушателем пробела». Действие сохранения происходит каждый раз, когда нажимается пробел.
  • Функция «on_select()» — это команда, сопоставленная с полем параметров, которая вызывается каждый раз, когда параметр изменяется посредством взаимодействия.
  • Остальной код проходит через создание пользовательского интерфейса приложения tkinter.

Запустив data_collection.py, мы должны визуализировать следующий результат:

Обучение

Чтобы обучить нашу модель классификации изображений, мы будем использовать архитектуру ResNet18.

Ниже приведен код train.py,

import sys
sys.path.append('helpers')

import queue

import tkinter
from tkinter import *
from tkinter import ttk
from tkinter import messagebox

import numpy as np
from PIL import Image, ImageTk
import cv2

from dataset import ImageClassificationDataset
from reader import read_dataset

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

import os
import json

import threading
import time
from utils import preprocess

global modelPathTF
global epochTF
global modelPathVar
global model
global device
global dataset
global lossValVar
global accValVar
global trainButton
global progBar
global root
global the_queue
global trainThread

def refresh_data():
        global the_queue
        global root
        
        global accValVar
        global lossValVar
        global progBar
        global modelPathTF

        global trainThread

        #print("thread status - ", trainThread.is_alive())

        if not trainThread.is_alive():
            save_model(modelPathTF.get())
            return
        
        # refresh the GUI with new data from the queue
        while not the_queue.empty():
            key, data = the_queue.get()
            #print("value from queue : {0}, {1}", key, data)
            if key == "prog":
                progBar["value"] = data
            elif key == "loss":
                lossValVar.set(data)
            elif key == "accu":
                accValVar.set(data)

        #  timer to refresh the gui with data from the asyncio thread
        root.after(100, refresh_data)  # called only once!

def start():
    global modelPathTF
    global epochTF
    global trainButton
    global the_queue
    global root
    
    global trainThread
    
    print("---")
    if int(epochTF.get()) <= 0:
        messagebox.showerror('Error', 'Provide an epoch value greater than 0')
        return
    
    trainButton.config(state="disabled")
    
    root.after(1, train_prepare(modelPathTF.get()))
    root.after(100, refresh_data);

    trainThread = threading.Thread(target=train, args=(the_queue, int(epochTF.get()), True))
    trainThread.start()
        
def stop():
    global root
    print("---")
    root.quit()
    print("Stopped!")

def launch():
    global modelPathTF
    global epochTF
    global modelPathVar
    global lossValVar
    global accValVar
    global trainButton
    global progBar
    global root

    root=tkinter.Tk()
    root.title("Train classifier")

    countVar = StringVar()
    countVar.set("Image count : {0}".format("-"))
    
    accValVar = StringVar()
    lossValVar = StringVar()

    valueInside = tkinter.StringVar(root)

    valueInside.set("Select a Label")

    label = Label(root, text="Model Path")
    font=('Calibri 12 bold')
    label.grid(row=0,column=0,pady=10,padx=10)

    modelPathVar = tkinter.StringVar()
    modelPathTF = tkinter.Entry( root, textvariable=modelPathVar)
    modelPathVar.set("models/model_v1.pth")
    # entry1.place(x = 80, y = 50)  
    modelPathTF.grid(row=0,column=1,pady=10,padx=5)

    label = Label(root, text="Epocs")
    font=('Calibri 12 bold')
    label.grid(row=1,column=0,pady=10,padx=10)

    epocValVar = tkinter.StringVar()
    epochTF = Entry(root, textvariable=epocValVar)
    epocValVar.set("0")
    # entry1.place(x = 80, y = 50)  
    epochTF.grid(row=1,column=1,pady=10,padx=5,columnspan=1)

    label = Label(root, text="Epocs")
    font=('Calibri 12 bold')
    label.grid(row=1,column=0,pady=10,padx=10)

    label = Label(root, text="Current epoch")
    font=('Calibri 12 bold')
    label.grid(row=2,column=0,pady=10,padx=10)

    progBar = ttk.Progressbar(root,orient=HORIZONTAL, length=200,mode="determinate")
    progBar.grid(row=2,column=1,pady=10,padx=5)
    progBar['value']=0

    accLabel = Label(root, text="Accuracy")
    font=('Calibri 12 bold')
    accLabel.grid(row=3,column=0,pady=10,padx=10)

    accValLabel = Label(root, textvariable=accValVar)
    accValVar.set("{0}".format("-"))
    font=('Calibri 12 normal')
    accValLabel.grid(row=3,column=1,pady=10,padx=10)

    lossLabel = Label(root, text="Loss")
    font=('Calibri 12 bold')
    lossLabel.grid(row=4,column=0,pady=10,padx=10)

    lossValLabel = Label(root, textvariable=lossValVar)
    lossValVar.set("{0}".format("-"))
    font=('Calibri 12 normal')
    lossValLabel.grid(row=4,column=1,pady=10,padx=10)

    componentHeight=2
    trainButton=tkinter.Button(root,text="Train",command=start,height=5,width=20)
    trainButton.grid(row=5, column=0,sticky='W',pady=10,padx=10)
    trainButton.config(height=1*componentHeight,width=5, padx=50)

    componentHeight=2
    stopButton=tkinter.Button(root,text="Exit",command=stop,height=5,width=20)
    stopButton.grid(row=5, column=1,sticky='E',pady=10,padx=10)
    stopButton.config(height=1*componentHeight,width=5, padx=50)

    root.resizable(0, 0)
    root.mainloop()
    
def load_model(path):
    global model
    model.load_state_dict(torch.load(path))

def save_model(path):
    global model
    parDir = "models"
    isExist = os.path.exists(parDir)
    if not isExist:
       # Create a new directory because it does not exist
       os.makedirs(parDir)
    torch.save(model.state_dict(), path)
    
def train_prepare(modelPath):
    print(type(modelPath), modelPath)
    global model
    global device
    global dataset
    
    time.sleep(1)
    
    datasets, dataset = read_dataset()
    
    print("fc layer output - {0}".format(len(dataset.categories)))
    
    device = torch.device('cuda')
    
    # RESNET 18
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, len(dataset.categories))
    
    model = model.to(device)
    
    isExist = os.path.exists(modelPath)
    if isExist:
        load_model(modelPath)
    else:
        print("model not found in current path, may be this is the first time you are about to train!");

    # display(model_widget)
    print("model configured and model_widget created")

def train(the_queue, epochs, is_training):
    global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, dataset, optimizer
    global device
    
    print("all data types - ", type(the_queue), type(epochs), type(is_training))
    total_epocs = epochs
        
    BATCH_SIZE = 3
    optimizer = torch.optim.Adam(model.parameters())
    elspased_epocs = 0
    
    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        time.sleep(1)

        if is_training:
            model = model.train()
        else:
            model = model.eval()
        while total_epocs > 0:
            i = 0
            sum_loss = 0.0
            error_count = 0.0
            for images, labels in iter(train_loader):
                # send data to device
                images = images.to(device)
                labels = labels.to(device)

                if is_training:
                    # zero gradients of parameters
                    optimizer.zero_grad()

                # execute model to get outputs
                outputs = model(images)

                # compute loss
                loss = F.cross_entropy(outputs, labels)

                if is_training:
                    # run backpropogation to accumulate gradients
                    loss.backward()

                    # step optimizer to adjust parameters
                    optimizer.step()

                # increment progress
                error_count += len(torch.nonzero(outputs.argmax(1) - labels).flatten())
                count = len(labels.flatten())
                i += count
                sum_loss += float(loss)
                
                # print("loss actual - {0}", sum_loss / i)
                # print("accuracy actual - {0}", 1.0 - error_count / i)

                the_queue.put(("accu", "{0}".format(1.0 - error_count / i)))
                the_queue.put(("loss", "{0}".format(sum_loss / i)))
        
                # lossValVar.set("{0}".format(sum_loss / i))
                # accValVar.set("{0}".format(1.0 - error_count / i))
            
            elspased_epocs = elspased_epocs + 1
            if is_training:
                total_epocs = total_epocs - 1
                print("elspased_epocs", elspased_epocs)
                # epochs_widget.value = epochs_widget.value - 1
                progBar['value'] = elspased_epocs * (100 / epochs)
            else:
                break
    except Exception as ex:
        print(ex)
        pass
    model = model.eval()
    
trainThread = None
the_queue = queue.Queue()

launch()

  • Функция «refresh_data ()» вызывается tkinter, а затем рекурсивно, если поток обучения не завершится. Такая настройка нужна для того, чтобы обновить элементы пользовательского интерфейса tkinter. Нам важно отметить, что refresh_data() использует очередь. Почему? Это связано с тем, что tkinter работает в основном потоке и, следовательно, его виджеты не могут быть изменены из отдельного потока. Из-за этого я прибегал к использованию очереди, в которой обучающий поток помещает данные в очередь, а refresh_data() считывает их для обновления виджетов пользовательского интерфейса.
  • Функция «start()» создает дескриптор для чтения изображений под соответствующими метками из набора данных, затем регистрирует refresh_data() для рекурсивного вызова, а затем запускает обучающий поток для обучения модели.
  • Функция «launch()» создает пользовательский интерфейс tkinter и регистрирует переменные и команды виджета. Пользовательский интерфейс предлагает возможность указать путь для сохранения модели. Кроме того, он также позволяет указать количество эпох для обучения модели.
  • Функция «save_model()» сохраняет модель после того, как обучающий поток завершает свою работу по обучению и завершает работу.
  • Функция «train_prepare()» подготавливает приложение к обучению, создавая ссылку на предварительно обученную модель ResNet18, а также инициализируя torch в режиме Cuda для обучения.
  • Функция «train()» обучает модель заданному количеству итераций и обновляет информацию о ходе выполнения в очереди, такую ​​как точность и потери на итерацию. Основной поток берет данные из очереди и соответствующим образом обновляет пользовательский интерфейс.
  • Пожалуйста, обратите внимание на потери и точность. Потери должны уменьшаться со временем, а точность должна постоянно улучшаться, чтобы достичь хорошего уровня. Если мы видим отклонения, мы должны пересмотреть наш набор данных.
  • После успешного обучения мы должны увидеть файл с именем model_v1.pth в папке моделей.

После запуска файла train.py мы должны увидеть следующие результаты:

Классификация

Мы будем использовать прямую трансляцию с камеры для классификации дизайнов коробок в режиме реального времени.

Ниже приведен код classify.py,

#append src folder to read dataset.py from it
import sys
sys.path.append('helpers')

import tkinter
from tkinter import *
from tkinter import ttk
import numpy as np
from PIL import Image, ImageTk
import cv2
from jetcam.usb_camera import USBCamera
from jetcam.utils import bgr8_to_jpeg

import os
import threading
import time
import queue

from utils import preprocess

from tkinter import messagebox

from dataset import ImageClassificationDataset

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from reader import read_dataset

global frame
global cam
global currFrame
global classificationVar
global classifyThread
global root
global theQueue

def update_image(change):
    global currFrame
    image = change['new']
    currFrame = image
    
def refresh_data():
    global theQueue
    global root
    global classificationVar
    global classifyThread

    #print("thread status - ", classifyThread.is_alive())

    if not classifyThread.is_alive():
        return

    # refresh the GUI with new data from the queue
    while not theQueue.empty():
        key, data = theQueue.get()
        #print("value from queue : {0}, {1}", key, data)
        if key == "result":
            classificationVar.set(data)

    #  timer to refresh the gui with data from the asyncio thread
    root.after(100, refresh_data)  # called only once!
        
def start():
    global frame
    global cam
    global currFrame
    global classificationVar
    global root
    global classifyThread
    
    startButton.config(state="disabled")

    datasets, dataset = read_dataset()
        
    device = torch.device('cuda')
   
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, len(dataset.categories))

    model = model.to(device)

    model.load_state_dict(torch.load(modelPathTF.get()))
    
    root.after(100, refresh_data);
        
    classifyThread = threading.Thread(target=live, args=((cam, dataset, model, theQueue)))
    classifyThread.start()
    
    while cam.running:
        frame = currFrame #cam.read()

        #Update the image to tkinter...
        frame=cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
        PIL_image = Image.fromarray(np.uint8(frame)).convert('RGB')
        img_update = ImageTk.PhotoImage(PIL_image)
        imageView.configure(image=img_update)
        imageView.image=img_update
        imageView.update()

        if frame.shape[0] == 0:
            print("failed to grab frame")
            break

def live(camera, dataset, model, theQueue):
    while cam.running:
        image = camera.value
        preprocessed = preprocess(image)
        output = model(preprocessed)
        output = F.softmax(output, dim=1).detach().cpu().numpy().flatten()
        category_index = output.argmax()
        #print("---prediction---")
        
        predictedclass = dataset.categories[category_index]
        #print(predictedclass)
        for i, score in enumerate(list(output)):
            #print(i, score)
            if dataset.categories[i] == predictedclass:
                rounded = round(float(score), 2)
                # theQueue.put(("result", "{0} @ {1}".format(predictedclass, rounded)))
                theQueue.put(("result", "{0}".format(predictedclass)))

        #print("---end---")
            
def stop():
    global cam
    global root
    
    cam.running = False
    cam.unobserve(update_image, names='value')
    root.quit()
    print("Stopped!")

classifyThread = None
theQueue = queue.Queue()
    
cam = USBCamera(width=224, height=224, capture_width=640, capture_height=480, capture_device=0)
image = cam.read()

cam.running = True
cam.observe(update_image, names='value')
    
print(image.shape)

root=tkinter.Tk()
root.title("Classifier app")

frame=np.random.randint(0,255,[100,100,3],dtype='uint8')
img = ImageTk.PhotoImage(Image.fromarray(frame))

imageView=tkinter.Label(root) #,image=img)
imageView.grid(row=0,column=0,columnspan=3,pady=1,padx=10)

modelPathVar = tkinter.StringVar()
modelPathTF = tkinter.Entry( root, textvariable=modelPathVar)
modelPathVar.set("models/model_v1.pth")
# entry1.place(x = 80, y = 50)  
modelPathTF.grid(row=1,column=1,pady=10,padx=5)

label = Label(root, text="Classification result")
font=('Calibri 14 bold')
label.grid(row=2,column=1,pady=0,padx=10)

classificationVar = tkinter.StringVar()
label = Label(root, textvariable=classificationVar)
classificationVar.set("Waiting...")
font=('Calibri 12')
label.config(fg="#0000FF")
label.config(bg="yellow")
label.grid(row=3,column=1,pady=10,padx=10)

component_height=2
startButton=tkinter.Button(root,text="Start",command=start,height=5,width=20)
startButton.grid(row=4,column=0,pady=10,padx=10)
startButton.config(height=1*component_height,width=5)

component_height=2
stopButton=tkinter.Button(root,text="Exit",command=stop,height=5,width=20)
stopButton.grid(row=4,column=2,pady=10,padx=10)
stopButton.config(height=1*component_height,width=5)

root.mainloop()
  • Функция «update_image()» сопоставляется с экземпляром камеры jetcam, который вызывается для каждого кадра. Эта функция создает ссылку на текущий кадр для виджета изображения tkinter для отображения прямой трансляции.
  • Функция «refresh_data ()» вызывается tkinter, а затем рекурсивно, если поток классификации не завершится. Он запускается рекурсивно для обновления элементов пользовательского интерфейса tkinter. Функция refresh_data() считывает очередь, обновленную потоком классификации, для обновления виджетов пользовательского интерфейса.
  • Функция «start()» запускает цикл, который обновляет виджет изображения текущим кадром. Затем Torch инициализируется в режиме Cuda с последующей загрузкой обученной модели ResNet18 по пути, указанному в виджете ввода. Наконец, поток классификации создается и запускается.
  • Функция «live()» берет текущий кадр камеры и отправляет его в модель в нужной форме, чтобы получить результат логического вывода. Результат классификации и оценка помещаются в очередь для обновления пользовательского интерфейса с помощью функции refresh_data().
  • Функция «stop()» удаляет текущий наблюдатель кадра из экземпляра камеры и завершает работу приложения.
  • Остальной код проходит через создание пользовательского интерфейса приложения tkinter.
  • Обратите внимание, что метка с желтой подсветкой указывает на текущий результат классификации.

Запуск classify.py дает следующий результат:

Я надеюсь, что это руководство послужит хорошей отправной точкой для практики нашего первого проекта машинного обучения с использованием Jetson.

Я добавил ссылку на код проекта в github. Не стесняйтесь разветвлять его и использовать повторно. Любые вопросы, которые у вас могут возникнуть, пишите, я постараюсь ответить на них в выходные.



Я также добавил демо-видео, набор данных и модель на,



Добрый день!