Всем привет! Чтобы классифицировать конструкции коробок, которые могли бы принести пользу моей упаковочной компании, я недавно создал приложение для машинного обучения, используя платформу Jetson. Цель состояла в том, чтобы автоматизировать такие задачи, как печать ценников на коробках на основе дизайна и предотвращение ошибок, таких как отправка не тех коробок в больших количествах. В следующих разделах я кратко объясню настройку и код.
Для сборки приложения у нас должно быть следующее оборудование,
- Jetson AGX Xavier (мы могли бы также использовать более простое устройство, такое как nano)
- USB-камера (Logitech)
- Мышь и клавиатура
- Монитор
- сетевой кабель
И последнее, но не менее важное: хост-система Linux для установки JetPack и загрузки Jetson AGX Xavier. NVIDIA также предоставляет возможность загрузки с SD-карты. Чтобы узнать больше, прочитайте следующую ссылку,
Подключите их и начните с установки JetPack через SDK Manager. Я установил его по официальной ссылке (https://developer.nvidia.com/embedded/jetpack) и имею следующие версии JetPack и Cuda,
- Версия реактивного ранца: 5.1.1 [L4T 35.3.1]
- Версия Cuda: инструменты компиляции Cuda, выпуск 11.4, V11.4.315.
После этого мы можем использовать образ докера от NVIDIA, чтобы запустить контейнер среды выполнения и приступить к практической работе.
Ссылка: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-base
Обратите внимание, что после того, как мы настроили наш Jetson для загрузки, мы можем разработать приложение двумя способами. Один через режим ssh, а другой через обычный ПК, например настройку. Я предпочел обычный подход, так как был более склонен к созданию приложения Python на основе Tkinter, которому требовалось устройство отображения для тестирования пользовательского интерфейса.
Структура папок проекта приведена ниже,
Мы будем использовать вспомогательный класс ImageClassificationDataset из helpers/dataset.py, который сохраняет изображения в соответствующие папки меток с желаемым именем набора данных.
class ImageClassificationDataset(torch.utils.data.Dataset): def __init__(self, directory, categories, transform=None): self.categories = categories self.directory = directory self.transform = transform self._refresh() def __len__(self): return len(self.annotations) def __getitem__(self, idx): ann = self.annotations[idx] image = cv2.imread(ann['image_path'], cv2.IMREAD_COLOR) image = PIL.Image.fromarray(image) if self.transform is not None: image = self.transform(image) return image, ann['category_index'] def _refresh(self): self.annotations = [] for category in self.categories: category_index = self.categories.index(category) for image_path in glob.glob(os.path.join(self.directory, category, '*.jpg')): self.annotations += [{ 'image_path': image_path, 'category_index': category_index, 'category': category }] def save_entry(self, image, category): """Saves an image in BGR8 format to dataset for category""" if category not in self.categories: raise KeyError('There is no category named %s in this dataset.' % category) filename = str(uuid.uuid1()) + '.jpg' category_directory = os.path.join(self.directory, category) if not os.path.exists(category_directory): subprocess.call(['mkdir', '-p', category_directory]) image_path = os.path.join(category_directory, filename) cv2.imwrite(image_path, image) self._refresh() return image_path def get_count(self, category): i = 0 for a in self.annotations: if a['category'] == category: i += 1 return i
Как видите, конструктор принимает параметр каталога, соответствующий имени набора данных, а затем параметр категории, определяющий метки, которые должны быть размещены в наборе данных.
Разобравшись с помощником проекта и набора данных, давайте перейдем к сбору данных с последующим обучением и классификацией.
Сбор данных
Чтобы все модули использовали одну и ту же конфигурацию, я определил config.json следующим образом:
{ "config": { "task": "diff", "datasets": [ "2023q1", "2023q2" ], "labels": [ "None", "Floral101", "Floral102" ] } }
После этого я также создал функцию чтения внутри файла read.py,
from dataset import ImageClassificationDataset import torchvision.transforms as transforms import json def read_dataset(): f = open('config.json') json_data = json.load(f) TRANSFORMS = transforms.Compose([ transforms.ColorJitter(0.2, 0.2, 0.2, 0.2), transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) datasets = {} for name in json_data["config"]["datasets"]: datasets[name] = ImageClassificationDataset(json_data["config"]["task"] + '_' + name, json_data["config"]["labels"], TRANSFORMS) dataset = datasets[json_data["config"]["datasets"][0]] print(json_data) print(json_data["config"]["task"]) print("{0} task with {1} categories defined".format(json_data["config"]["task"], json_data["config"]["labels"])) categories = json_data["config"]["labels"] return (datasets, dataset)
Эта функция чтения использует файл json для создания каталога набора данных как «diff_2023q1» и пометки папок как «Нет», «Цветочный101», «Цветочный102» соответственно.
Вызов функции read_dataset подготовит нас к созданию следующей структуры папок. По мере сбора данных с помощью нашего приложения для сбора данных мы будем больше понимать структуру набора данных.
Ниже приведен код из data_collection.py,
#append src folder to read dataset.py from it import sys sys.path.append('helpers') import tkinter from tkinter import * from tkinter import ttk import numpy as np from PIL import Image, ImageTk import cv2 from jetcam.usb_camera import USBCamera from jetcam.utils import bgr8_to_jpeg import os from tkinter import messagebox from dataset import ImageClassificationDataset import torchvision.transforms as transforms from reader import read_dataset global cam global currFrame global datasets global dataset global countVar global root def prepare(): global datasets global dataset datasets, dataset = read_dataset() def get_label_selection(): print("Dataset label: {}".format(value_inside.get())) return value_inside.get() def update_image(change): global currFrame image = change['new'] currFrame = image def start(): global frame global cam global currFrame global countVar option = get_label_selection() if option not in dataset.categories: messagebox.showerror('Error', 'Please select a label!') return countVar.set("Image count : {0}".format(dataset.get_count(option))) # cam = cv2.VideoCapture(0) #cv2.namedWindow("Experience_in_AI camera") while cam.running: frame = currFrame #cam.read() #Update the image to tkinter... frame=cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) PIL_image = Image.fromarray(np.uint8(frame)).convert('RGB') img_update = ImageTk.PhotoImage(PIL_image) paneeli_image.configure(image=img_update) paneeli_image.image=img_update paneeli_image.update() if frame.shape[0] == 0: print("failed to grab frame") break def stop(): global cam global root cam.running = False cam.unobserve(update_image, names='value') root.quit() print("Stopped!") def save(event=None): global dataset global cam option = get_label_selection() print("save called") dataset.save_entry(cam.value, get_label_selection()) countVar.set("Image count : {0}".format(dataset.get_count(option))) def on_select(event): print("on_select called") option = get_label_selection() countVar.set("Image count : {0}".format(dataset.get_count(option))) # create_dataset_folder() prepare() cam = USBCamera(width=224, height=224, capture_width=640, capture_height=480, capture_device=0) image = cam.read() cam.running = True cam.observe(update_image, names='value') print(image.shape) root=tkinter.Tk() root.title("Dataset creator app") countVar = StringVar() countVar.set("Image count : {0}".format("-")) frame=np.random.randint(0,255,[100,100,3],dtype='uint8') img = ImageTk.PhotoImage(Image.fromarray(frame)) paneeli_image=tkinter.Label(root) #,image=img) paneeli_image.grid(row=0,column=0,columnspan=3,pady=1,padx=10) # Variable to keep track of the option # selected in OptionMenu value_inside = tkinter.StringVar(root) # Set the default value of the variable value_inside.set("Select a Label") # Create the optionmenu widget and passing # the options_list and value_inside to it. question_menu = tkinter.OptionMenu(root, value_inside, *dataset.categories, command = on_select) question_menu.grid(row=1,column=1,pady=1,padx=10) component_height=5 startButton=tkinter.Button(root,text="Start",command=start,height=5,width=20) startButton.grid(row=1,column=0,pady=10,padx=10) startButton.config(height=1*component_height,width=10) component_height=5 stopButton=tkinter.Button(root,text="Exit",command=stop,height=5,width=20) stopButton.grid(row=1,column=2,pady=10,padx=10) stopButton.config(height=1*component_height,width=10) label = Label(root, textvariable=countVar) font=('Calibri 14 bold') label.grid(row=2,column=1,pady=10,padx=10) label = Label(root, text="Use spacebar to save a snapshot") font=('Calibri 12 bold') label.grid(row=3,column=1,pady=10,padx=10) root.bind("<space>", save) root.mainloop()
- «prepare()» создает дескриптор для организации изображений под соответствующими метками.
- «get_label_selection()» возвращает текущую выбранную опцию в виджете, которая является интересующей меткой.
- Функция «update_image()» прикреплена в качестве функции наблюдателя к экземпляру камеры jetcam, которая вызывается при каждом кадре. Эта функция создает ссылку на текущий кадр для виджета изображения tkinter для отображения текущего кадра (для прямой трансляции).
- Функция «start()» запускает цикл, который обновляет виджет изображения текущим кадром.
- Функция «stop()» удаляет текущий наблюдатель кадра из экземпляра камеры и завершает работу приложения.
- Функция «сохранить ()» сохраняет изображение в соответствующую папку меток в наборе данных и сопоставляется с «слушателем пробела». Действие сохранения происходит каждый раз, когда нажимается пробел.
- Функция «on_select()» — это команда, сопоставленная с полем параметров, которая вызывается каждый раз, когда параметр изменяется посредством взаимодействия.
- Остальной код проходит через создание пользовательского интерфейса приложения tkinter.
Запустив data_collection.py, мы должны визуализировать следующий результат:
Обучение
Чтобы обучить нашу модель классификации изображений, мы будем использовать архитектуру ResNet18.
Ниже приведен код train.py,
import sys sys.path.append('helpers') import queue import tkinter from tkinter import * from tkinter import ttk from tkinter import messagebox import numpy as np from PIL import Image, ImageTk import cv2 from dataset import ImageClassificationDataset from reader import read_dataset import torch import torchvision import torchvision.transforms as transforms import torch.nn.functional as F import os import json import threading import time from utils import preprocess global modelPathTF global epochTF global modelPathVar global model global device global dataset global lossValVar global accValVar global trainButton global progBar global root global the_queue global trainThread def refresh_data(): global the_queue global root global accValVar global lossValVar global progBar global modelPathTF global trainThread #print("thread status - ", trainThread.is_alive()) if not trainThread.is_alive(): save_model(modelPathTF.get()) return # refresh the GUI with new data from the queue while not the_queue.empty(): key, data = the_queue.get() #print("value from queue : {0}, {1}", key, data) if key == "prog": progBar["value"] = data elif key == "loss": lossValVar.set(data) elif key == "accu": accValVar.set(data) # timer to refresh the gui with data from the asyncio thread root.after(100, refresh_data) # called only once! def start(): global modelPathTF global epochTF global trainButton global the_queue global root global trainThread print("---") if int(epochTF.get()) <= 0: messagebox.showerror('Error', 'Provide an epoch value greater than 0') return trainButton.config(state="disabled") root.after(1, train_prepare(modelPathTF.get())) root.after(100, refresh_data); trainThread = threading.Thread(target=train, args=(the_queue, int(epochTF.get()), True)) trainThread.start() def stop(): global root print("---") root.quit() print("Stopped!") def launch(): global modelPathTF global epochTF global modelPathVar global lossValVar global accValVar global trainButton global progBar global root root=tkinter.Tk() root.title("Train classifier") countVar = StringVar() countVar.set("Image count : {0}".format("-")) accValVar = StringVar() lossValVar = StringVar() valueInside = tkinter.StringVar(root) valueInside.set("Select a Label") label = Label(root, text="Model Path") font=('Calibri 12 bold') label.grid(row=0,column=0,pady=10,padx=10) modelPathVar = tkinter.StringVar() modelPathTF = tkinter.Entry( root, textvariable=modelPathVar) modelPathVar.set("models/model_v1.pth") # entry1.place(x = 80, y = 50) modelPathTF.grid(row=0,column=1,pady=10,padx=5) label = Label(root, text="Epocs") font=('Calibri 12 bold') label.grid(row=1,column=0,pady=10,padx=10) epocValVar = tkinter.StringVar() epochTF = Entry(root, textvariable=epocValVar) epocValVar.set("0") # entry1.place(x = 80, y = 50) epochTF.grid(row=1,column=1,pady=10,padx=5,columnspan=1) label = Label(root, text="Epocs") font=('Calibri 12 bold') label.grid(row=1,column=0,pady=10,padx=10) label = Label(root, text="Current epoch") font=('Calibri 12 bold') label.grid(row=2,column=0,pady=10,padx=10) progBar = ttk.Progressbar(root,orient=HORIZONTAL, length=200,mode="determinate") progBar.grid(row=2,column=1,pady=10,padx=5) progBar['value']=0 accLabel = Label(root, text="Accuracy") font=('Calibri 12 bold') accLabel.grid(row=3,column=0,pady=10,padx=10) accValLabel = Label(root, textvariable=accValVar) accValVar.set("{0}".format("-")) font=('Calibri 12 normal') accValLabel.grid(row=3,column=1,pady=10,padx=10) lossLabel = Label(root, text="Loss") font=('Calibri 12 bold') lossLabel.grid(row=4,column=0,pady=10,padx=10) lossValLabel = Label(root, textvariable=lossValVar) lossValVar.set("{0}".format("-")) font=('Calibri 12 normal') lossValLabel.grid(row=4,column=1,pady=10,padx=10) componentHeight=2 trainButton=tkinter.Button(root,text="Train",command=start,height=5,width=20) trainButton.grid(row=5, column=0,sticky='W',pady=10,padx=10) trainButton.config(height=1*componentHeight,width=5, padx=50) componentHeight=2 stopButton=tkinter.Button(root,text="Exit",command=stop,height=5,width=20) stopButton.grid(row=5, column=1,sticky='E',pady=10,padx=10) stopButton.config(height=1*componentHeight,width=5, padx=50) root.resizable(0, 0) root.mainloop() def load_model(path): global model model.load_state_dict(torch.load(path)) def save_model(path): global model parDir = "models" isExist = os.path.exists(parDir) if not isExist: # Create a new directory because it does not exist os.makedirs(parDir) torch.save(model.state_dict(), path) def train_prepare(modelPath): print(type(modelPath), modelPath) global model global device global dataset time.sleep(1) datasets, dataset = read_dataset() print("fc layer output - {0}".format(len(dataset.categories))) device = torch.device('cuda') # RESNET 18 model = torchvision.models.resnet18(pretrained=True) model.fc = torch.nn.Linear(512, len(dataset.categories)) model = model.to(device) isExist = os.path.exists(modelPath) if isExist: load_model(modelPath) else: print("model not found in current path, may be this is the first time you are about to train!"); # display(model_widget) print("model configured and model_widget created") def train(the_queue, epochs, is_training): global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, dataset, optimizer global device print("all data types - ", type(the_queue), type(epochs), type(is_training)) total_epocs = epochs BATCH_SIZE = 3 optimizer = torch.optim.Adam(model.parameters()) elspased_epocs = 0 try: train_loader = torch.utils.data.DataLoader( dataset, batch_size=BATCH_SIZE, shuffle=True ) time.sleep(1) if is_training: model = model.train() else: model = model.eval() while total_epocs > 0: i = 0 sum_loss = 0.0 error_count = 0.0 for images, labels in iter(train_loader): # send data to device images = images.to(device) labels = labels.to(device) if is_training: # zero gradients of parameters optimizer.zero_grad() # execute model to get outputs outputs = model(images) # compute loss loss = F.cross_entropy(outputs, labels) if is_training: # run backpropogation to accumulate gradients loss.backward() # step optimizer to adjust parameters optimizer.step() # increment progress error_count += len(torch.nonzero(outputs.argmax(1) - labels).flatten()) count = len(labels.flatten()) i += count sum_loss += float(loss) # print("loss actual - {0}", sum_loss / i) # print("accuracy actual - {0}", 1.0 - error_count / i) the_queue.put(("accu", "{0}".format(1.0 - error_count / i))) the_queue.put(("loss", "{0}".format(sum_loss / i))) # lossValVar.set("{0}".format(sum_loss / i)) # accValVar.set("{0}".format(1.0 - error_count / i)) elspased_epocs = elspased_epocs + 1 if is_training: total_epocs = total_epocs - 1 print("elspased_epocs", elspased_epocs) # epochs_widget.value = epochs_widget.value - 1 progBar['value'] = elspased_epocs * (100 / epochs) else: break except Exception as ex: print(ex) pass model = model.eval() trainThread = None the_queue = queue.Queue() launch()
- Функция «refresh_data ()» вызывается tkinter, а затем рекурсивно, если поток обучения не завершится. Такая настройка нужна для того, чтобы обновить элементы пользовательского интерфейса tkinter. Нам важно отметить, что refresh_data() использует очередь. Почему? Это связано с тем, что tkinter работает в основном потоке и, следовательно, его виджеты не могут быть изменены из отдельного потока. Из-за этого я прибегал к использованию очереди, в которой обучающий поток помещает данные в очередь, а refresh_data() считывает их для обновления виджетов пользовательского интерфейса.
- Функция «start()» создает дескриптор для чтения изображений под соответствующими метками из набора данных, затем регистрирует refresh_data() для рекурсивного вызова, а затем запускает обучающий поток для обучения модели.
- Функция «launch()» создает пользовательский интерфейс tkinter и регистрирует переменные и команды виджета. Пользовательский интерфейс предлагает возможность указать путь для сохранения модели. Кроме того, он также позволяет указать количество эпох для обучения модели.
- Функция «save_model()» сохраняет модель после того, как обучающий поток завершает свою работу по обучению и завершает работу.
- Функция «train_prepare()» подготавливает приложение к обучению, создавая ссылку на предварительно обученную модель ResNet18, а также инициализируя torch в режиме Cuda для обучения.
- Функция «train()» обучает модель заданному количеству итераций и обновляет информацию о ходе выполнения в очереди, такую как точность и потери на итерацию. Основной поток берет данные из очереди и соответствующим образом обновляет пользовательский интерфейс.
- Пожалуйста, обратите внимание на потери и точность. Потери должны уменьшаться со временем, а точность должна постоянно улучшаться, чтобы достичь хорошего уровня. Если мы видим отклонения, мы должны пересмотреть наш набор данных.
- После успешного обучения мы должны увидеть файл с именем model_v1.pth в папке моделей.
После запуска файла train.py мы должны увидеть следующие результаты:
Классификация
Мы будем использовать прямую трансляцию с камеры для классификации дизайнов коробок в режиме реального времени.
Ниже приведен код classify.py,
#append src folder to read dataset.py from it import sys sys.path.append('helpers') import tkinter from tkinter import * from tkinter import ttk import numpy as np from PIL import Image, ImageTk import cv2 from jetcam.usb_camera import USBCamera from jetcam.utils import bgr8_to_jpeg import os import threading import time import queue from utils import preprocess from tkinter import messagebox from dataset import ImageClassificationDataset import torch import torchvision import torchvision.transforms as transforms import torch.nn.functional as F from reader import read_dataset global frame global cam global currFrame global classificationVar global classifyThread global root global theQueue def update_image(change): global currFrame image = change['new'] currFrame = image def refresh_data(): global theQueue global root global classificationVar global classifyThread #print("thread status - ", classifyThread.is_alive()) if not classifyThread.is_alive(): return # refresh the GUI with new data from the queue while not theQueue.empty(): key, data = theQueue.get() #print("value from queue : {0}, {1}", key, data) if key == "result": classificationVar.set(data) # timer to refresh the gui with data from the asyncio thread root.after(100, refresh_data) # called only once! def start(): global frame global cam global currFrame global classificationVar global root global classifyThread startButton.config(state="disabled") datasets, dataset = read_dataset() device = torch.device('cuda') model = torchvision.models.resnet18(pretrained=True) model.fc = torch.nn.Linear(512, len(dataset.categories)) model = model.to(device) model.load_state_dict(torch.load(modelPathTF.get())) root.after(100, refresh_data); classifyThread = threading.Thread(target=live, args=((cam, dataset, model, theQueue))) classifyThread.start() while cam.running: frame = currFrame #cam.read() #Update the image to tkinter... frame=cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) PIL_image = Image.fromarray(np.uint8(frame)).convert('RGB') img_update = ImageTk.PhotoImage(PIL_image) imageView.configure(image=img_update) imageView.image=img_update imageView.update() if frame.shape[0] == 0: print("failed to grab frame") break def live(camera, dataset, model, theQueue): while cam.running: image = camera.value preprocessed = preprocess(image) output = model(preprocessed) output = F.softmax(output, dim=1).detach().cpu().numpy().flatten() category_index = output.argmax() #print("---prediction---") predictedclass = dataset.categories[category_index] #print(predictedclass) for i, score in enumerate(list(output)): #print(i, score) if dataset.categories[i] == predictedclass: rounded = round(float(score), 2) # theQueue.put(("result", "{0} @ {1}".format(predictedclass, rounded))) theQueue.put(("result", "{0}".format(predictedclass))) #print("---end---") def stop(): global cam global root cam.running = False cam.unobserve(update_image, names='value') root.quit() print("Stopped!") classifyThread = None theQueue = queue.Queue() cam = USBCamera(width=224, height=224, capture_width=640, capture_height=480, capture_device=0) image = cam.read() cam.running = True cam.observe(update_image, names='value') print(image.shape) root=tkinter.Tk() root.title("Classifier app") frame=np.random.randint(0,255,[100,100,3],dtype='uint8') img = ImageTk.PhotoImage(Image.fromarray(frame)) imageView=tkinter.Label(root) #,image=img) imageView.grid(row=0,column=0,columnspan=3,pady=1,padx=10) modelPathVar = tkinter.StringVar() modelPathTF = tkinter.Entry( root, textvariable=modelPathVar) modelPathVar.set("models/model_v1.pth") # entry1.place(x = 80, y = 50) modelPathTF.grid(row=1,column=1,pady=10,padx=5) label = Label(root, text="Classification result") font=('Calibri 14 bold') label.grid(row=2,column=1,pady=0,padx=10) classificationVar = tkinter.StringVar() label = Label(root, textvariable=classificationVar) classificationVar.set("Waiting...") font=('Calibri 12') label.config(fg="#0000FF") label.config(bg="yellow") label.grid(row=3,column=1,pady=10,padx=10) component_height=2 startButton=tkinter.Button(root,text="Start",command=start,height=5,width=20) startButton.grid(row=4,column=0,pady=10,padx=10) startButton.config(height=1*component_height,width=5) component_height=2 stopButton=tkinter.Button(root,text="Exit",command=stop,height=5,width=20) stopButton.grid(row=4,column=2,pady=10,padx=10) stopButton.config(height=1*component_height,width=5) root.mainloop()
- Функция «update_image()» сопоставляется с экземпляром камеры jetcam, который вызывается для каждого кадра. Эта функция создает ссылку на текущий кадр для виджета изображения tkinter для отображения прямой трансляции.
- Функция «refresh_data ()» вызывается tkinter, а затем рекурсивно, если поток классификации не завершится. Он запускается рекурсивно для обновления элементов пользовательского интерфейса tkinter. Функция refresh_data() считывает очередь, обновленную потоком классификации, для обновления виджетов пользовательского интерфейса.
- Функция «start()» запускает цикл, который обновляет виджет изображения текущим кадром. Затем Torch инициализируется в режиме Cuda с последующей загрузкой обученной модели ResNet18 по пути, указанному в виджете ввода. Наконец, поток классификации создается и запускается.
- Функция «live()» берет текущий кадр камеры и отправляет его в модель в нужной форме, чтобы получить результат логического вывода. Результат классификации и оценка помещаются в очередь для обновления пользовательского интерфейса с помощью функции refresh_data().
- Функция «stop()» удаляет текущий наблюдатель кадра из экземпляра камеры и завершает работу приложения.
- Остальной код проходит через создание пользовательского интерфейса приложения tkinter.
- Обратите внимание, что метка с желтой подсветкой указывает на текущий результат классификации.
Запуск classify.py дает следующий результат:
Я надеюсь, что это руководство послужит хорошей отправной точкой для практики нашего первого проекта машинного обучения с использованием Jetson.
Я добавил ссылку на код проекта в github. Не стесняйтесь разветвлять его и использовать повторно. Любые вопросы, которые у вас могут возникнуть, пишите, я постараюсь ответить на них в выходные.
Я также добавил демо-видео, набор данных и модель на,
Добрый день!