>

>

>

KI erzeugt Trainingsdaten für KI

KI erzeugt Trainingsdaten für KI

KI erzeugt Trainingsdaten für KI

Generierung synthetischer Trainingsdaten mit Generative Adversarial Network (GAN): Proof of Concept anhand des MNIST-Datensatzes

Generierung synthetischer Trainingsdaten mit Generative Adversarial Network (GAN): Proof of Concept anhand des MNIST-Datensatzes

Einleitung

Bild­erkennung (engl.: image recognition) wird heutzutage zunehmend zu einem zentralen Bestandteil zahlreicher Anwendungen. Darunter fallen beispielsweise Straßenschilderkennung in modernen Autos, KI gestützte Qualitätssicherung in Warenherstellungsprozessen oder die medizinische Diagnostik im Gesundheitswesen, etwa bei der Analyse von Röntgenbildern oder der Erkennung von Tumoren. Die Performance neuronaler Netze die zu dieser Bilderkennung trainiert wurden hängt jedoch stark von dem vorhandenen Trainingsdatensatz ab. Sind in diesem zum Beispiel nur wenige Bilder eines gewissen Hautkrebstypus vorhanden, so sinkt auch die Wahrscheinlichkeit, dass die KI diese Art des Hautkrebs in der echten Anwendung erkennt.

Um dies vorzubeugen gibt es verschiedene Möglichkeiten den Trainingsdatensatz zu erweitern. Diese werden unter dem Oberbegriff Data Augementation zusammengefasst. In der Bild­erkennung gehört das Drehen und Spiegeln von Bildern zu den am häufigsten eingesetzten Methoden der Data Augmentation, wodurch ein einzelnes Bild zu vielen Trainingsbildern erweitert werden kann.

In diesem Projekt verfolgen wir das Ziel, Data Augmentation durch die künstliche Generierung von Bildern umzusetzen. Dazu entwickeln wir ein generatives Netzwerk, genauer ein Generative Adversarial Network (GAN) in PyTorch, trainieren es auf dem vorhandenen Datensatz, erzeugen damit künstliche Bilder und erweitern den ursprünglichen Datensatz anschließend mit diesen synthetischen Daten. Den Erfolg dieses Ansatzes messen wir, indem wir ein Bilderkennungsnetz sowohl mit dem ursprünglichen als auch mit dem erweiterten Datensatz trainieren und anschließend die erzielten Genauigkeiten vergleichen.

Wir haben uns für einen der bekanntesten Bilddatensätze, den Standard‑Datensatz MNIST, entschieden. Dieser enthält handgeschriebene Zahlen von 0 bis 9 in Graustufen auf 28×28 Pixeln. Insgesamt umfasst der Datensatz 60.000 Trainings- und 10.000 Testbilder. Für uns steht der Lerneffekt und nicht die Entwicklung eines konkurrenzfähigen Netzwerks im Vordergrund. Das bedeutet, dass wir primär an einem Proof of Concept interessiert sind. Alle benutzten Netzwerke wurden selbst geschrieben und von Grund auf trainiert. Gearbeitet wurde auf der Platform Google Colab, da man dort einen zeitlich bergenzten Zugang zu kostenlosen GPUs hat, was für ein effizientes Training unbedingt nötig ist.

Methodik

Ein Generative Adversarial Network (GAN) besteht aus zwei neuronalen Netzen, dem Generator und dem Discriminator, die als Gegenspieler agieren. Zunächst wird der Discriminator mit Bildern aus dem echten Datensatz trainiert, um diese korrekt als echt zu erkennen. Der Generator erhält einen zufälligen Noise-Vektor als Eingabe und erzeugt daraus anfangs völlig zufällige, verpixelte Bilder im gleichen Format wie der ursprüngliche Datensatz. Diese künstlichen Bilder werden anschließend dem Discriminator präsentiert. Während des Trainings lernen beide Netzwerke gleichzeitig: der Discriminator, die generierten Bilder als künstlich zu erkennen, und der Generator, den Discriminator zu täuschen, sodass dieser die künstlichen Bilder fälschlicherweise als echt klassifiziert.

Unser initial gebautes und trainiertes GAN erzeugte zunächst zufällige Zahlen von 0 bis 9. Deshalb haben wir es zu einem Conditional GAN erweitert, also einem Netzwerk, dem wir eine Bedingung (engl.: condition) übergeben können. In unserem Fall legt die Bedingung fest, welche Zahl erzeugt werden soll. Das Titelbild zeigt das Ergebnis: Abhängig von der Nutzerangabe generiert der Generator pro Zeile zunächst Nullen, dann Einsen usw.

Zu diesem GAN haben wir zusätzlich ein Klassifizierungsnetzwerk implementiert, das darauf trainiert werden kann, die verschiedenen MNIST-Zahlen korrekt zu erkennen. Konkret bedeutet dies: Wird dem Netzwerk ein Bild einer „1“ übergeben, soll es diese als „1“ klassifizieren. Die Genauigkeit des Netzwerks wird dabei anhand der Anzahl korrekt klassifizierter Bilder aus dem Testdatensatz bestimmt.

Mit diesem Setup verfügen wir über alle Komponenten für unsere Data-Augmentation-Tests. Die Pipeline gestaltet sich wie folgt: Zunächst laden wir den MNIST-Datensatz herunter. Den Testsatz belassen wir unverändert, während wir vom Trainingssatz nur einen Teil verwenden, beispielsweise 10 % der Originalbilder. Mit diesem reduzierten Trainingsset trainieren wir anschließend unser GAN. Das Training läuft so lange, bis die Frechet Inception Distance (FID), eine Metrik zur Bewertung der Qualität des GAN, ein Plateau erreicht oder wieder zu steigen beginnt. Sobald dies der Fall ist, brechen wir das Training ab, da die vom Generator erzeugten Bilder sonst wieder an Qualität verlieren.

Anschließend nutzen wir das reduzierte MNIST-Set, um unser Klassifizierungsnetzwerk zu trainieren. Das Training läuft hier ebenfalls so lange bis die Genauigkeit der Klassifizierung ein Maximum erreicht, es wird also gestoppt bevor diese wieder abnimmt. Diese Ergebnisse vergleichen wir mit der Genauigkeit eines Klassifizierungsnetzwerks, das sowohl mit den reduzierten MNIST-Bildern als auch mit den künstlich erzeugten Bildern trainiert wurde. Diese Pipeline wird dann für verschiedene Anteile an Originalbildern durchlaufen und verglichen.

Ergebnis

Die Ergebnisse unseres Verfahrens sind im dargestellten Plot zusammengefasst. Deutlich erkennbar ist, dass das Klassifikationsnetz, das zusätzlich mit reduzierten Originaldaten + Data Augmentation trainiert wurde, eine höhere Genauigkeit erreicht als jenes, welches ausschließlich mit den reduzierten Originaldaten trainiert wurde. Der Unterschied in der Genauigkeit ist dabei umso größer, je kleiner der verwendete Anteil des MNIST-Trainingsdatensatzes ist. Für größere Anteile nähern sich die Ergebnisse beider Modelle zunehmend an.

Dieses Verhalten ist zu erwarten: Wird der vollständige Datensatz genutzt, erreicht der Klassifikator bereits eine Genauigkeit von etwa 99,5 %. Damit werden von 10.000 Testbildern nur noch rund 50 falsch klassifiziert. In diesem Bereich ist das Verbesserungspotenzial sehr gering, sodass zusätzliche künstliche Trainingsdaten kaum noch zu einer merklichen Steigerung der Genauigkeit führen können.

Generell fällt die Verbesserung durch Data Augmentation jedoch geringer aus als erwartet. Die größte Verbesserung beobachten wir bei einem Trainingsanteil von 2,5 % des MNIST-Datensatzes; hier werden knapp 40 Testbilder mehr korrekt klassifiziert.

Ein möglicher Grund dafür liegt in der Leistungsfähigkeit des verwendeten GAN, da dieses für die Qualität der künstlich erstellten Bilder verantwortlich ist. Zwar haben wir verschiedene Einstellungen ausprobiert, jedoch keine systematische Analyse der Architektur vorgenommen, etwa hinsichtlich der optimalen Anzahl und Art der verwendeten Layer.

Darüber hinaus spielt auch die verwendete Bewertungsmetrik eine wichtige Rolle. Die Qualität des GAN wird anhand der FID beurteilt. Sollte diese Metrik die Bildqualität nicht zuverlässig widerspiegeln, kann es passieren, dass eine bestimmte Netzwerkkonfiguration als besser bewertet wird, obwohl die erzeugten Bilder aus menschlicher Sicht weniger überzeugend sind als die einer anderen Konfiguration. Da wir in unserem Setup die FID ebenfalls über ein eigens implementiertes und trainiertes Netzwerk berechnen gibt es auch hier vermutlich noch Verbesserungsspielraum.

Die Qualität der verwendeten Klassifikationsnetzwerke spielt in diesem Zusammenhang eine eher untergeordnete Rolle. Zwar ließen sich diese Netzwerke durch weitere Optimierungen vermutlich noch verbessern, jedoch vergleichen wir zwei identische Architekturen miteinander. Es ist daher davon auszugehen, dass Verbesserungen der Netzwerkstruktur sowohl die Genauigkeit des Modells mit Data Augmentation als auch die des Modells ohne Data Augmentation erhöhen würden, ohne den relativen Unterschied zwischen beiden wesentlich zu verändern.

Zusammenfassend lässt sich festhalten, dass wir mit unserer Pipeline eine kleine, aber messbare Verbesserung der Klassifikationsgenauigkeit durch GAN-basierte Data Augmentation beobachten können. Perspektivisch ließe sich das verwendete GAN weiter optimieren, etwa durch eine systematischere Analyse der Netzwerkarchitektur und der Trainingsparameter auf dem MNIST-Datensatz. Darüber hinaus wäre die Anwendung auf realistischere Problemstellungen ein naheliegender nächster Schritt, beispielsweise in der medizinischen Bildanalyse wie der Hautkrebserkennung. Hierfür müssten Generator und Discriminator im Wesentlichen nur an die jeweilige Bildauflösung und die Anzahl der Klassen des verwendeten Datensatzes angepasst werden.

Team

Paula Mors

Ole Körner

Alexander Feike

Mentor:in

Maximilian Hahn

Unsere Partner

Unsere Partner

Unsere Partner