[自用代码]将原始数据集进行划分成训练集、验证集和测试集,并计算权重

发布时间:2023-11-19 13:00

# coding: utf-8
"""
    将原始数据集进行划分成训练集、验证集和测试集
"""

import os
import glob
import random
import shutil

dataset_dir = 'img'
train_dir = 'train-8/'
valid_dir = 'valid-2/'
#test_dir = './data/test/'

train_per = 0.8
valid_per = 0.2
#test_per = 0.1


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

def cal_weight(ll: list) -> list:
    ll = [sum(ll)/x for x in ll]
    ll = [round(x/max(ll),3) for x in ll]
    return ll 
    
if __name__ == '__main__':

    class_weight = []
    for root, dirs, files in os.walk(dataset_dir):
        for sDir in dirs:
            imgs_list = glob.glob(os.path.join(root, sDir)+'/*.jpg')
            random.seed()
            random.shuffle(imgs_list)
            imgs_num = len(imgs_list)
            class_weight.append(imgs_num)

            train_point = int(imgs_num * train_per)
            valid_point = int(imgs_num * (train_per + valid_per))

            for i in range(imgs_num):
                if i < train_point:
                    out_dir = train_dir + sDir + '/'
                elif i < valid_point:
                    out_dir = valid_dir + sDir + '/'
                else:
                    out_dir = test_dir + sDir + '/'

                makedir(out_dir)
                out_path = out_dir + os.path.split(imgs_list[i])[-1]
                shutil.copy(imgs_list[i], out_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))
    print(cal_weight(class_weight))

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号