soft-analytics-02/plot_acc.py

17 lines
415 B
Python

import os.path
import pandas as pd
from train.finetune import plot_loss_acc
# Absolute directory containing this script. abspath guards against a
# relative __file__ (e.g. the __main__ script on Python < 3.9), which would
# otherwise make every path built from ROOT depend on the current working
# directory instead of on the script's location.
ROOT = os.path.dirname(os.path.abspath(__file__))
def main(run_dir=None):
    """Plot loss and accuracy curves from a run's ``stats.csv``.

    Reads ``stats.csv`` from *run_dir*. Its ``train_loss``, ``val_loss``,
    ``train_acc`` and ``val_acc`` columns are converted to plain lists and
    passed, in that order, to ``train.finetune.plot_loss_acc`` together
    with *run_dir*. NOTE(review): the last argument is presumably where
    the plot is written; confirm against ``plot_loss_acc``.

    Args:
        run_dir: Directory containing ``stats.csv``. Defaults to
            ``<ROOT>/models/final``, the location this script has always
            used, so calling ``main()`` with no arguments is unchanged.

    Raises:
        FileNotFoundError: If ``stats.csv`` is missing (from ``pd.read_csv``).
        KeyError: If one of the four expected columns is absent.
    """
    if run_dir is None:
        run_dir = os.path.join(ROOT, 'models', 'final')
    df = pd.read_csv(os.path.join(run_dir, 'stats.csv'))
    # Same column order as plot_loss_acc's positional parameters.
    train_loss, val_loss, train_acc, val_acc = (
        df[col].tolist()
        for col in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
    )
    plot_loss_acc(train_loss, val_loss, train_acc, val_acc, run_dir)
if __name__ == "__main__":
    # Generate the plots only when executed directly; importing this module
    # (e.g. to reuse main) has no side effects.
    main()