常用功能函数

发布时间:2024-12-20 16:01

1. 读取以制表符分隔的 .txt 文件，随机打乱行顺序后保存为 .csv 文件（如 test.csv）

def create_csv(txt_path, csv_path, random_state=None):
    """Read a tab-separated text file, shuffle its rows and save them as CSV.

    Args:
        txt_path: Path of the input file. Fields are separated by tabs and the
            file has no header row.
        csv_path: Destination CSV path. The written file has a header row made
            of the integer column labels (0, 1, ...) that pandas assigns for
            ``header=None``, and no index column.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle can be reproduced. ``None`` (the default) gives a new
            random order on every call.
    """
    # Use a single literal tab. A multi-character separator such as r"\t" is
    # treated as a regex by pandas, which forces the slower Python engine.
    rows = pd.read_csv(txt_path, sep="\t", header=None)
    # frac=1 samples every row exactly once, i.e. a full shuffle.
    rows = rows.sample(frac=1, random_state=random_state)
    rows.to_csv(csv_path, index=False)
    print("Finish save csv")

2. 准确率 & 损失结果可视化

# Let plot titles and labels render Chinese characters: use SimHei as the
# sans-serif font, and draw minus signs as a plain ASCII hyphen, because the
# Unicode minus glyph may be missing from that font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

def _plot_train_val(train_values, val_values, metric, title, save_path):
    """Plot one metric's train/val curves on a new figure, save it, then show it.

    Args:
        train_values: Per-epoch values of the metric on the training set.
        val_values: Per-epoch values of the metric on the validation set.
        metric: Metric name. Used as the y-axis label and as the suffix of the
            legend labels (``train_<metric>`` / ``val_<metric>``).
        title: Figure title.
        save_path: Image file to write. Its parent directory is created if it
            does not exist yet.
    """
    import os

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # plt.savefig raises if the target directory does not exist.
        os.makedirs(save_dir, exist_ok=True)

    # Start a fresh figure. Otherwise a second call made after plt.show() on a
    # non-interactive backend would draw onto the previous call's axes.
    plt.figure()
    plt.plot(train_values, label='train_' + metric)
    plt.plot(val_values, label='val_' + metric)
    plt.legend(loc='best')
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.title(title)
    plt.savefig(save_path)
    plt.show()


def plot_loss(train_loss, val_loss, save_path='results/loss.png'):
    """Plot training vs. validation loss per epoch and save it to *save_path*."""
    _plot_train_val(train_loss, val_loss, 'loss',
                    "训练集和验证集loss值的对比图", save_path)


def plot_acc(train_acc, val_acc, save_path='results/acc.png'):
    """Plot training vs. validation accuracy per epoch and save it to *save_path*."""
    _plot_train_val(train_acc, val_acc, 'acc',
                    "训练集和验证集acc值的对比图", save_path)


def plot_results(epochs, train_acc, train_loss, test_acc, test_loss,
                 save_path='results/result.png'):
    """Plot train/test accuracy and loss curves together on one figure.

    Args:
        epochs: Number of epochs. The x-axis is ``0 .. epochs - 1``, so each
            series must contain ``epochs`` values (matplotlib requires x and y
            to have the same length).
        train_acc, train_loss, test_acc, test_loss: Per-epoch series to plot.
        save_path: Image file to write. Its parent directory is created if it
            does not exist yet.
    """
    import os

    x = np.arange(epochs)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # plt.savefig raises if the target directory does not exist.
        os.makedirs(save_dir, exist_ok=True)

    # Start a fresh figure so these curves are not drawn over a figure that an
    # earlier plotting call left open (e.g. on a non-interactive backend).
    plt.figure()
    series = (
        (train_acc, 'train_acc'),
        (train_loss, 'train_loss'),
        (test_acc, 'test_acc'),
        (test_loss, 'test_loss'),
    )
    for values, label in series:
        plt.plot(x, values, label=label)

    plt.title("Results", fontsize=15)
    plt.xlabel("X", fontsize=13)
    plt.ylabel("Y", fontsize=13)
    plt.legend()
    plt.savefig(save_path)
    plt.show()

3. 将目录下的 .jpg 图片文件路径保存到 txt

from glob import glob
def save_txt(data_path, save_path, pattern='*.jpg'):
    """Write the sorted paths of the matching files in *data_path* to a text file.

    Args:
        data_path: Directory to search (not recursive).
        save_path: Output text file, written as UTF-8 with one path per line.
            Every line, including the last, ends with ``'\\n'``.
        pattern: Glob pattern for file names inside ``data_path``. Defaults to
            ``'*.jpg'``; on POSIX systems the match is case-sensitive.
    """
    import os
    from glob import escape

    # Escape the directory part so characters such as '[' in a folder name
    # are matched literally instead of being read as glob syntax.
    search = os.path.join(escape(data_path), pattern)
    files = sorted(glob(search))  # all matching image file paths

    with open(save_path, "w", encoding='UTF-8') as out:
        out.writelines(path + '\n' for path in files)

4. 读入txt

def ReadTxt(rootdir='test.txt', encoding='UTF-8'):
    """Return the lines of a text file with their trailing newline removed.

    Args:
        rootdir: Path of the text file to read. Despite the name, this is a
            file, not a directory.
        encoding: Text encoding of the file. Defaults to UTF-8 rather than the
            platform's locale encoding, so the result does not depend on the OS.

    Returns:
        A list with one string per line. Empty lines are kept as ``''``.
    """
    with open(rootdir, 'r', encoding=encoding) as file_to_read:
        # Text mode already turns '\r\n' and '\r' into '\n', so stripping '\n'
        # is enough to remove every line terminator.
        return [line.rstrip('\n') for line in file_to_read]

5. 获取图片文件名（不含扩展名）并写入 txt

# Write the base names (extension removed) of the .jpg/.jpeg/.png files found
# directly in ``path`` to test.txt, one name per line, in sorted order.
import os

path = '/data/'
with open("test.txt", "w", encoding='UTF-8') as a:
    # Sort so the output order does not depend on os.listdir's arbitrary order.
    for item in sorted(os.listdir(path)):
        stem, ext = os.path.splitext(item)
        # Compare case-insensitively so names such as "IMG_1.JPG" are included.
        if ext.lower() in ('.jpg', '.jpeg', '.png'):
            # Write the name itself. splitext keeps inner dots ("a.b.jpg" ->
            # "a.b"); str() of a split list would write "['a', 'b']" instead.
            a.write(stem + '\n')

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号