Робочий процес 🔥PyTorch

Технології комп’ютерного зору

Автор
Приналежність

Ігор Мірошниченко

КНУ імені Тараса Шевченка

  1. Створіть набір даних прямої лінії, використовуючи формулу лінійної регресії (bias + weight * X), а також додайте трохи випадкового шуму (наприклад, torch.randn()). При цьому самостійно оберіть початкове значення генератора випадкових чисел.
    • Випадково оберіть bias і weight. Точок має бути 100+.
    • Розділіть дані на 80% для навчальної та 20% для тестової вибірки.
    • Побудуйте графік даних.

  1. Створіть модель PyTorch, створивши підклас nn.Module.
    • Всередині повинен бути випадково ініціалізований nn.Parameter() з requires_grad=True, один для weights і один для bias.
    • Реалізуйте метод forward() для обчислення функції лінійної регресії, яку ви використовували для створення набору даних в 1.
    • Після побудови моделі створіть її екземпляр і перевірте його state_dict().
Примітка

Якщо ви хочете використовувати nn.Linear() замість nn.Parameter(), ви можете це зробити.

  1. Створіть функцію втрат та оптимізатор, використовуючи відповідно nn.L1Loss() та torch.optim.SGD(params, lr).
    • Встановіть швидкість навчання оптимізатора на 0.01, а параметрами для оптимізації мають бути параметри моделі, яку ви створили в пункті 2.
    • Напишіть цикл навчання, щоб виконати відповідні кроки навчання протягом 300 епох.
    • Цикл навчання повинен тестувати модель на тестовому наборі даних кожні 20 епох.
  2. Зробіть прогнози за допомогою навченої моделі на тестових даних.
    • Візуалізуйте ці прогнози на тлі оригінальних навчальних і тестових даних.
Примітка

Якщо ви хочете використовувати бібліотеки, що не підтримують CUDA, такі як matplotlib, для побудови графіків, вам може знадобитися переконатися, що прогнози не виконуються на GPU.

  1. Збережіть state_dict() вашої навченої моделі у файлі.
    • Створіть новий екземпляр класу моделі, який ви створили в пункті 2, і завантажте в нього щойно збережений state_dict().
    • Виконайте прогнозування на ваших тестових даних за допомогою завантаженої моделі та переконайтеся, що вони відповідають прогнозам оригінальної моделі з пункту 4.