diff --git a/.gitignore b/.gitignore index 7f81bbeb..708d8cc5 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ lib/pycocotools lib/pycocotools/_mask.c lib/pycocotools/_mask.so .idea +data/fontdataset +.DS_Store diff --git a/README.md b/README.md index 7b9011af..77ae752f 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ If you find it useful, the ``data/cache`` folder created on my side is also shar **Note**: If you cannot get the reported numbers (79.8 on my side), then probably the NMS function is compiled improperly, refer to [Issue 5](https://github.com/endernewton/tf-faster-rcnn/issues/5). ### Train your own model -1. Download pre-trained models and weights. The current code support VGG16 and Resnet V1 models. Pre-trained models are provided by slim, you can get the pre-trained models [here](https://github.com/tensorflow/models/tree/master/slim#pre-trained-models) and set them in the ``data/imagenet_weights`` folder. For example for VGG16 model, you can set up like: +1. Download pre-trained models and weights. The current code support VGG16 and Resnet V1 models. Pre-trained models are provided by slim, you can get the pre-trained models [here](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models) and set them in the ``data/imagenet_weights`` folder. For example for VGG16 model, you can set up like: ```Shell mkdir -p data/imagenet_weights cd data/imagenet_weights diff --git a/data/scripts/fetch_faster_rcnn_models.sh b/data/scripts/fetch_faster_rcnn_models.sh index b5fcab4c..7a3ce1c9 100755 --- a/data/scripts/fetch_faster_rcnn_models.sh +++ b/data/scripts/fetch_faster_rcnn_models.sh @@ -7,6 +7,7 @@ NET=res101 FILE=voc_0712_80k-110k.tgz # replace it with gs11655.sp.cs.cmu.edu if ladoga.graphics.cs.cmu.edu does not work URL=http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/$NET/$FILE +URL=http://xinlei.sp.cs.cmu.edu/xinleic/tf-faster-rcnn/$NET/$FILE CHECKSUM=cb32e9df553153d311cc5095b2f8c340 if [ -f $FILE ]; then diff --git a/experiments/scripts/demo_fontdataset.sh b/experiments/scripts/demo_fontdataset.sh new file mode 100755 index 00000000..6218f110 --- /dev/null +++ b/experiments/scripts/demo_fontdataset.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -x +set -e + +export PYTHONUNBUFFERED="True" + +GPU_ID=$1 + +CUDA_VISIBLE_DEVICES=${GPU_ID} time python tools/demo_fontdataset.py --testimg data/fontdataset --net res101 --model output/res101/fontdataset_trainval/default/res101_faster_rcnn_iter_490000.ckpt --dataset fontdataset_test --index data/fontdataset/test.txt diff --git a/experiments/scripts/test_faster_rcnn.sh b/experiments/scripts/test_faster_rcnn.sh index 7d288c7b..ccd803b4 100755 --- a/experiments/scripts/test_faster_rcnn.sh +++ b/experiments/scripts/test_faster_rcnn.sh @@ -15,6 +15,14 @@ EXTRA_ARGS=${array[@]:3:$len} EXTRA_ARGS_SLUG=${EXTRA_ARGS// /_} case ${DATASET} in + fontdataset) + TRAIN_IMDB="fontdataset_trainval" + TEST_IMDB="fontdataset_test" + STEPSIZE="[350000]" + ITERS=490000 + ANCHORS="[2,3,4,5,6,8,16,32]" + RATIOS="[0.5,1,2]" + ;; pascal_voc) TRAIN_IMDB="voc_2007_trainval" TEST_IMDB="voc_2007_test" diff --git a/experiments/scripts/train_faster_rcnn.sh b/experiments/scripts/train_faster_rcnn.sh index d5de07c0..3f33893e 100755 --- a/experiments/scripts/train_faster_rcnn.sh +++ b/experiments/scripts/train_faster_rcnn.sh @@ -15,6 +15,14 @@ EXTRA_ARGS=${array[@]:3:$len} EXTRA_ARGS_SLUG=${EXTRA_ARGS// /_} case ${DATASET} in + fontdataset) + TRAIN_IMDB="fontdataset_trainval" + TEST_IMDB="fontdataset_test" + STEPSIZE="[350000]" + ITERS=490000 + ANCHORS="[2,3,4,5,6,8,16,32]" + RATIOS="[0.5,1,2]" + ;; pascal_voc) TRAIN_IMDB="voc_2007_trainval" TEST_IMDB="voc_2007_test" @@ -82,4 +90,9 @@ if [ ! -f ${NET_FINAL}.index ]; then fi fi -./experiments/scripts/test_faster_rcnn.sh $@ +#./experiments/scripts/test_faster_rcnn.sh $@ + +./experiments/scripts/demo_fontdataset.sh 0 > demo_11000char_result.txt + + +sudo poweroff diff --git a/fonts/.DS_Store b/fonts/.DS_Store new file mode 100644 index 00000000..5008ddfc Binary files /dev/null and b/fonts/.DS_Store differ diff --git a/fonts/JejuGothic-Regular.ttf b/fonts/JejuGothic-Regular.ttf new file mode 100755 index 00000000..a6886c35 Binary files /dev/null and b/fonts/JejuGothic-Regular.ttf differ diff --git a/fonts/JejuHallasan-Regular.ttf b/fonts/JejuHallasan-Regular.ttf new file mode 100755 index 00000000..522a76df Binary files /dev/null and b/fonts/JejuHallasan-Regular.ttf differ diff --git a/fonts/JejuMyeongjo-Regular.ttf b/fonts/JejuMyeongjo-Regular.ttf new file mode 100755 index 00000000..aeed811a Binary files /dev/null and b/fonts/JejuMyeongjo-Regular.ttf differ diff --git a/fonts/KoPubBatang-Bold.ttf b/fonts/KoPubBatang-Bold.ttf new file mode 100755 index 00000000..62471312 Binary files /dev/null and b/fonts/KoPubBatang-Bold.ttf differ diff --git a/fonts/KoPubBatang-Light.ttf b/fonts/KoPubBatang-Light.ttf new file mode 100755 index 00000000..f4e9a173 Binary files /dev/null and b/fonts/KoPubBatang-Light.ttf differ diff --git a/fonts/KoPubBatang-Regular.ttf b/fonts/KoPubBatang-Regular.ttf new file mode 100755 index 00000000..c212794d Binary files /dev/null and b/fonts/KoPubBatang-Regular.ttf differ diff --git a/fonts/NanumGothic-Bold.ttf b/fonts/NanumGothic-Bold.ttf new file mode 100755 index 00000000..c24b0e73 Binary files /dev/null and b/fonts/NanumGothic-Bold.ttf differ diff --git a/fonts/NanumGothic-ExtraBold.ttf b/fonts/NanumGothic-ExtraBold.ttf new file mode 100755 index 00000000..c85adc7d Binary files /dev/null and b/fonts/NanumGothic-ExtraBold.ttf differ diff --git a/fonts/NanumGothic-Regular.ttf b/fonts/NanumGothic-Regular.ttf new file mode 100755 index 00000000..c14ce884 Binary files /dev/null and b/fonts/NanumGothic-Regular.ttf differ diff --git a/fonts/NanumGothicCoding-Bold.ttf b/fonts/NanumGothicCoding-Bold.ttf new file mode 100755 index 00000000..9aa9ba2a Binary files /dev/null and b/fonts/NanumGothicCoding-Bold.ttf differ diff --git a/fonts/NanumGothicCoding-Regular.ttf b/fonts/NanumGothicCoding-Regular.ttf new file mode 100755 index 00000000..ba77a9da Binary files /dev/null and b/fonts/NanumGothicCoding-Regular.ttf differ diff --git a/fonts/NanumMyeongjo-Bold.ttf b/fonts/NanumMyeongjo-Bold.ttf new file mode 100755 index 00000000..ecac4470 Binary files /dev/null and b/fonts/NanumMyeongjo-Bold.ttf differ diff --git a/fonts/NanumMyeongjo-ExtraBold.ttf b/fonts/NanumMyeongjo-ExtraBold.ttf new file mode 100755 index 00000000..6dd48fbc Binary files /dev/null and b/fonts/NanumMyeongjo-ExtraBold.ttf differ diff --git a/fonts/NanumMyeongjo-Regular.ttf b/fonts/NanumMyeongjo-Regular.ttf new file mode 100755 index 00000000..a47e1a6f Binary files /dev/null and b/fonts/NanumMyeongjo-Regular.ttf differ diff --git a/fonts/NotoSansKR-Black.otf b/fonts/NotoSansKR-Black.otf new file mode 100644 index 00000000..68c9dd8f Binary files /dev/null and b/fonts/NotoSansKR-Black.otf differ diff --git a/fonts/NotoSansKR-Bold.otf b/fonts/NotoSansKR-Bold.otf new file mode 100644 index 00000000..a75bd0af Binary files /dev/null and b/fonts/NotoSansKR-Bold.otf differ diff --git a/fonts/NotoSansKR-Light.otf b/fonts/NotoSansKR-Light.otf new file mode 100644 index 00000000..2a195751 Binary files /dev/null and b/fonts/NotoSansKR-Light.otf differ diff --git a/fonts/NotoSansKR-Medium.otf b/fonts/NotoSansKR-Medium.otf new file mode 100644 index 00000000..7aeeadf8 Binary files /dev/null and b/fonts/NotoSansKR-Medium.otf differ diff --git a/fonts/NotoSansKR-Regular.otf b/fonts/NotoSansKR-Regular.otf new file mode 100644 index 00000000..929719bb Binary files /dev/null and b/fonts/NotoSansKR-Regular.otf differ diff --git a/fonts/NotoSansKR-Thin.otf b/fonts/NotoSansKR-Thin.otf new file mode 100644 index 00000000..9878079e Binary files /dev/null and b/fonts/NotoSansKR-Thin.otf differ diff --git a/lib/datasets/factory.py b/lib/datasets/factory.py index 3283d0b5..f55d9255 100644 --- a/lib/datasets/factory.py +++ b/lib/datasets/factory.py @@ -13,6 +13,7 @@ __sets = {} from datasets.pascal_voc import pascal_voc from datasets.coco import coco +import datasets.fontdataset import numpy as np @@ -39,6 +40,11 @@ name = 'coco_{}_{}'.format(year, split) __sets[name] = (lambda split=split, year=year: coco(split, year)) +# FONT dataset +for split in ['train', 'val', 'trainval', 'test']: + name = 'fontdataset_{}'.format(split) + __sets[name] = (lambda split=split: + datasets.fontdataset.fontdataset(split)) def get_imdb(name): """Get an imdb (image database) by name.""" diff --git a/lib/datasets/fontdataset.py b/lib/datasets/fontdataset.py new file mode 100644 index 00000000..0b64e907 --- /dev/null +++ b/lib/datasets/fontdataset.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- + +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import os +from datasets.imdb import imdb +import datasets.ds_utils as ds_utils +import xml.etree.ElementTree as ET +import numpy as np +import scipy.sparse +import scipy.io as sio +import utils.cython_bbox +import cPickle +import subprocess +import uuid +from fontdataset_eval import fontdataset_eval +from model.config import cfg +import io +import pdb + + +class fontdataset(imdb): + def __init__(self, image_set, devkit_path=None): + imdb.__init__(self, 'fontdataset_' + image_set) + # self._year = year + self._year = 2007 + self._image_set = image_set + self._devkit_path = self._get_default_path() if devkit_path is None \ + else devkit_path + self._data_path = os.path.join(self._devkit_path, 'fontdataset') + self._classes = self._get_classes() + self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) + self._image_ext = '.jpg' + self._image_index = self._load_image_set_index() + # Default to roidb handler + #self._roidb_handler = self.selective_search_roidb + self._roidb_handler = self.gt_roidb + self._salt = str(uuid.uuid4()) + self._comp_id = 'comp4' + + # PASCAL specific config options + self.config = {'cleanup' : True, + 'use_salt' : True, + 'use_diff' : False, + 'matlab_eval' : False, + 'rpn_file' : None, + 'min_size' : 2} + + assert os.path.exists(self._devkit_path), \ + 'fontdataset path does not exist: {}'.format(self._devkit_path) + assert os.path.exists(self._data_path), \ + 'Path does not exist: {}'.format(self._data_path) + + def image_path_at(self, i): + """ + Return the absolute path to image i in the image sequence. + """ + return self.image_path_from_index(self._image_index[i]) + + def image_path_from_index(self, index): + """ + Construct an image path from the image's "index" identifier. + """ + image_path = os.path.join(self._data_path, 'images', + index + self._image_ext) + assert os.path.exists(image_path), \ + 'Path does not exist: {}'.format(image_path) + return image_path + + def _load_image_set_index(self): + """ + Load the indexes listed in this dataset's image set file. + """ + # Example path to image set file: + # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt + image_set_file = os.path.join(self._data_path, + self._image_set + '.txt') + assert os.path.exists(image_set_file), \ + 'Path does not exist: {}'.format(image_set_file) + with open(image_set_file) as f: + image_index = [x.strip() for x in f.readlines()] + return image_index + + def _get_default_path(self): + """ + Return the default path where PASCAL VOC is expected to be installed. + """ + return os.path.join(cfg.DATA_DIR) + # return os.path.join(cfg.DATA_DIR, 'fontdataset') + + def _get_classes(self): + """ + Return the list of classes + """ + # with open(os.path.join(self._data_path, "labels.txt"), 'r') as rf: + with io.open(os.path.join(self._data_path, "labels.txt"), mode='r', encoding='utf-8') as rf: + content = rf.readlines() + + classes = ['__background__'] + [x.strip('\n') for x in content] + return classes + + def gt_roidb(self): + """ + Return the database of ground-truth regions of interest. + + This function loads/saves from/to a cache file to speed up future calls. + """ + cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') + if os.path.exists(cache_file): + with open(cache_file, 'rb') as fid: + roidb = cPickle.load(fid) + print '{} gt roidb loaded from {}'.format(self.name, cache_file) + return roidb + + gt_roidb = [self._load_fontdataset_annotation(index) + for index in self.image_index] + with open(cache_file, 'wb') as fid: + cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) + print 'wrote gt roidb to {}'.format(cache_file) + + return gt_roidb + + def selective_search_roidb(self): + """ + Return the database of selective search regions of interest. + Ground-truth ROIs are also included. + + This function loads/saves from/to a cache file to speed up future calls. + """ + cache_file = os.path.join(self.cache_path, + self.name + '_selective_search_roidb.pkl') + + if os.path.exists(cache_file): + with open(cache_file, 'rb') as fid: + roidb = cPickle.load(fid) + print '{} ss roidb loaded from {}'.format(self.name, cache_file) + return roidb + + if int(self._year) == 2007 or self._image_set != 'test': + gt_roidb = self.gt_roidb() + ss_roidb = self._load_selective_search_roidb(gt_roidb) + roidb = imdb.merge_roidbs(gt_roidb, ss_roidb) + else: + roidb = self._load_selective_search_roidb(None) + with open(cache_file, 'wb') as fid: + cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL) + print 'wrote ss roidb to {}'.format(cache_file) + + return roidb + + def rpn_roidb(self): + if int(self._year) == 2007 or self._image_set != 'test': + gt_roidb = self.gt_roidb() + rpn_roidb = self._load_rpn_roidb(gt_roidb) + roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb) + else: + roidb = self._load_rpn_roidb(None) + + return roidb + + def _load_rpn_roidb(self, gt_roidb): + filename = self.config['rpn_file'] + print 'loading {}'.format(filename) + assert os.path.exists(filename), \ + 'rpn data not found at: {}'.format(filename) + with open(filename, 'rb') as f: + box_list = cPickle.load(f) + return self.create_roidb_from_box_list(box_list, gt_roidb) + + def _load_selective_search_roidb(self, gt_roidb): + filename = os.path.abspath(os.path.join(cfg.DATA_DIR, + 'selective_search_data', + self.name + '.mat')) + assert os.path.exists(filename), \ + 'Selective search data not found at: {}'.format(filename) + raw_data = sio.loadmat(filename)['boxes'].ravel() + + box_list = [] + for i in xrange(raw_data.shape[0]): + boxes = raw_data[i][:, (1, 0, 3, 2)] - 1 + keep = ds_utils.unique_boxes(boxes) + boxes = boxes[keep, :] + keep = ds_utils.filter_small_boxes(boxes, self.config['min_size']) + boxes = boxes[keep, :] + box_list.append(boxes) + + return self.create_roidb_from_box_list(box_list, gt_roidb) + + def _load_fontdataset_annotation(self, index): + """ + Load image and bounding boxes info from XML file in the PASCAL VOC + format. + """ + filename = os.path.join(self._data_path, 'annotations', index + '.xml') + parser = ET.XMLParser(encoding='utf-8') + tree = ET.parse(filename, parser=parser) + objs = tree.findall('object') + # if not self.config['use_diff']: + # # Exclude the samples labeled as difficult + # non_diff_objs = [ + # obj for obj in objs if int(obj.find('difficult').text) == 0] + # # if len(non_diff_objs) != len(objs): + # # print 'Removed {} difficult objects'.format( + # # len(objs) - len(non_diff_objs)) + # objs = non_diff_objs + num_objs = len(objs) + + boxes = np.zeros((num_objs, 4), dtype=np.uint16) + gt_classes = np.zeros((num_objs), dtype=np.int32) + overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32) + # "Seg" area for pascal is just the box area + seg_areas = np.zeros((num_objs), dtype=np.float32) + + # Load object bounding boxes into a data frame. + for ix, obj in enumerate(objs): + bbox = obj.find('bndbox') + # Make pixel indexes 0-based + x1 = float(bbox.find('xmin').text) + y1 = float(bbox.find('ymin').text) + x2 = float(bbox.find('xmax').text) + y2 = float(bbox.find('ymax').text) + cls = self._class_to_ind[obj.find('name').text.strip()] + boxes[ix, :] = [x1, y1, x2, y2] + gt_classes[ix] = cls + overlaps[ix, cls] = 1.0 + seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) + + overlaps = scipy.sparse.csr_matrix(overlaps) + + return {'boxes' : boxes, + 'gt_classes': gt_classes, + 'gt_overlaps' : overlaps, + 'flipped' : False, + 'seg_areas' : seg_areas} + + def _get_comp_id(self): + comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt'] + else self._comp_id) + return comp_id + + def _get_voc_results_file_template(self): + # VOCdevkit/results/VOC2007/Main/_det_test_aeroplane.txt + filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt' + path = os.path.join( + self._devkit_path, + 'results', + 'font_dataset', + 'Main', + filename) + return path + + def _write_fontdataset_results_file(self, all_boxes): + for cls_ind, cls in enumerate(self.classes): + if cls == '__background__': + continue + print 'Writing {} fontdataset results file'.format(str(cls.encode('utf-8'))) + # filename = self._get_voc_results_file_template().format(str(cls.encode('utf-8'))) + filename = os.path.join(self._devkit_path, str(cls_ind) + '.txt') + with open(filename, 'wt') as f: + for im_ind, index in enumerate(self.image_index): + dets = all_boxes[cls_ind][im_ind] + if dets == []: + continue + # the VOCdevkit expects 1-based indices + for k in xrange(dets.shape[0]): + f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. + format(index, dets[k, -1], + dets[k, 0] + 1, dets[k, 1] + 1, + dets[k, 2] + 1, dets[k, 3] + 1)) + + def _do_python_eval(self, output_dir='output'): + annopath = os.path.join( + self._devkit_path, + 'fontdataset', + 'annotations', + '{:s}.xml') + imagesetfile = os.path.join( + self._devkit_path, + 'fontdataset', + self._image_set + '.txt') + cachedir = os.path.join(self._devkit_path, 'annotations_cache') + aps = [] + # The PASCAL VOC metric changed in 2010 + use_07_metric = True if int(self._year) < 2010 else False + print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No') + if not os.path.isdir(output_dir): + os.mkdir(output_dir) + for i, cls in enumerate(self._classes): + if cls == '__background__': + continue + # filename = self._get_voc_results_file_template().format(cls) + filename = os.path.join(self._devkit_path, str(i) + '.txt') + rec, prec, ap = fontdataset_eval( + filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5, + use_07_metric=use_07_metric) + aps += [ap] + # print('AP for {} = {:.4f}'.format(cls, ap)) + print('AP for {} = {:.4f}'.format(str(cls.encode('utf-8')), ap)) + print('AP for {} = ap:{} rec:{} prec:{}'.format(str(cls.encode('utf-8')), ap, rec, prec)) + # with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f: + with open(os.path.join(output_dir, str(i) + '_pr.pkl'), 'w') as f: + cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) + + print('Mean AP = {:.4f}'.format(np.mean(aps))) + print('~~~~~~~~') + print('Results:') + for ap in aps: + print('{:.3f}'.format(ap)) + print('Mean AP = {:.3f}'.format(np.mean(aps))) + print('~~~~~~~~') + print('') + print('--------------------------------------------------------------') + print('Results computed with the **unofficial** Python eval code.') + print('Results should be very close to the official MATLAB eval code.') + print('Recompute with `./tools/reval.py --matlab ...` for your paper.') + print('-- Thanks, The Management') + print('--------------------------------------------------------------') + + def _do_matlab_eval(self, output_dir='output'): + print '-----------------------------------------------------' + print 'Computing results with the official MATLAB eval code.' + print '-----------------------------------------------------' + path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets', + 'VOCdevkit-matlab-wrapper') + cmd = 'cd {} && '.format(path) + cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB) + cmd += '-r "dbstop if error; ' + cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \ + .format(self._devkit_path, self._get_comp_id(), + self._image_set, output_dir) + print('Running:\n{}'.format(cmd)) + status = subprocess.call(cmd, shell=True) + + def evaluate_detections(self, all_boxes, output_dir): + self._write_fontdataset_results_file(all_boxes) + self._do_python_eval(output_dir) + if self.config['matlab_eval']: + self._do_matlab_eval(output_dir) + if self.config['cleanup']: + for i, cls in enumerate(self._classes): + if cls == '__background__': + continue + # filename = self._get_voc_results_file_template().format(cls) + filename = os.path.join(self._devkit_path, str(i) + '.txt') + os.remove(filename) + + def competition_mode(self, on): + if on: + self.config['use_salt'] = False + self.config['cleanup'] = False + else: + self.config['use_salt'] = True + self.config['cleanup'] = True + +if __name__ == '__main__': + from datasets.phd08 import phd08 + d = phd08('trainval') + res = d.roidb + from IPython import embed; embed() diff --git a/lib/datasets/fontdataset_eval.py b/lib/datasets/fontdataset_eval.py new file mode 100644 index 00000000..75d0e58e --- /dev/null +++ b/lib/datasets/fontdataset_eval.py @@ -0,0 +1,208 @@ +# -------------------------------------------------------- +# Fast/er R-CNN +# Licensed under The MIT License [see LICENSE for details] +# Written by Bharath Hariharan +# -------------------------------------------------------- + +import xml.etree.ElementTree as ET +import os +import cPickle +import numpy as np +import pdb +def parse_rec(filename): + """ Parse a PASCAL VOC xml file """ + tree = ET.parse(filename) + objects = [] + for obj in tree.findall('object'): + obj_struct = {} + obj_struct['name'] = obj.find('name').text + # obj_struct['pose'] = obj.find('pose').text + # obj_struct['truncated'] = int(obj.find('truncated').text) + # obj_struct['difficult'] = int(obj.find('difficult').text) + obj_struct['pose'] = '' + obj_struct['truncated'] = 0 + obj_struct['difficult'] = 0 + bbox = obj.find('bndbox') + obj_struct['bbox'] = [int(bbox.find('xmin').text), + int(bbox.find('ymin').text), + int(bbox.find('xmax').text), + int(bbox.find('ymax').text)] + objects.append(obj_struct) + + return objects + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def fontdataset_eval(detpath, + annopath, + imagesetfile, + classname, + cachedir, + ovthresh=0.5, + use_07_metric=False): + """rec, prec, ap = voc_eval(detpath, + annopath, + imagesetfile, + classname, + [ovthresh], + [use_07_metric]) + + Top level function that does the PASCAL VOC evaluation. + + detpath: Path to detections + detpath.format(classname) should produce the detection results file. + annopath: Path to annotations + annopath.format(imagename) should be the xml annotations file. + imagesetfile: Text file containing the list of images, one image per line. + classname: Category name (duh) + cachedir: Directory for caching the annotations + [ovthresh]: Overlap threshold (default = 0.5) + [use_07_metric]: Whether to use VOC07's 11 point AP computation + (default False) + """ + # assumes detections are in detpath.format(classname) + # assumes annotations are in annopath.format(imagename) + # assumes imagesetfile is a text file with each line an image name + # cachedir caches the annotations in a pickle file + + # first load gt + if not os.path.isdir(cachedir): + os.mkdir(cachedir) + cachefile = os.path.join(cachedir, 'annots.pkl') + # read list of images + with open(imagesetfile, 'r') as f: + lines = f.readlines() + imagenames = [x.strip() for x in lines] + + if not os.path.isfile(cachefile): + # load annots + recs = {} + for i, imagename in enumerate(imagenames): + recs[imagename] = parse_rec(annopath.format(imagename)) + if i % 100 == 0: + print 'Reading annotation for {:d}/{:d}'.format( + i + 1, len(imagenames)) + # save + print 'Saving cached annotations to {:s}'.format(cachefile) + with open(cachefile, 'w') as f: + cPickle.dump(recs, f) + else: + # load + with open(cachefile, 'r') as f: + recs = cPickle.load(f) + + # extract gt objects for this class + class_recs = {} + npos = 0 + for imagename in imagenames: + R = [obj for obj in recs[imagename] if obj['name'] == classname] + bbox = np.array([x['bbox'] for x in R]) + difficult = np.array([x['difficult'] for x in R]).astype(np.bool) + det = [False] * len(R) + npos = npos + sum(~difficult) + class_recs[imagename] = {'bbox': bbox, + 'difficult': difficult, + 'det': det} + + # read dets + detfile = detpath.format(classname) + with open(detfile, 'r') as f: + lines = f.readlines() + if any(lines) == 1: + + splitlines = [x.strip().split(' ') for x in lines] + image_ids = [x[0] for x in splitlines] + confidence = np.array([float(x[1]) for x in splitlines]) + BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) + + # sort by confidence + sorted_ind = np.argsort(-confidence) + sorted_scores = np.sort(-confidence) + BB = BB[sorted_ind, :] + image_ids = [image_ids[x] for x in sorted_ind] + + # go down dets and mark TPs and FPs + nd = len(image_ids) + tp = np.zeros(nd) + fp = np.zeros(nd) + for d in range(nd): + R = class_recs[image_ids[d]] + bb = BB[d, :].astype(float) + ovmax = -np.inf + BBGT = R['bbox'].astype(float) + + if BBGT.size > 0: + # compute overlaps + # intersection + ixmin = np.maximum(BBGT[:, 0], bb[0]) + iymin = np.maximum(BBGT[:, 1], bb[1]) + ixmax = np.minimum(BBGT[:, 2], bb[2]) + iymax = np.minimum(BBGT[:, 3], bb[3]) + iw = np.maximum(ixmax - ixmin + 1., 0.) + ih = np.maximum(iymax - iymin + 1., 0.) + inters = iw * ih + + # union + uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) + + (BBGT[:, 2] - BBGT[:, 0] + 1.) * + (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) + + overlaps = inters / uni + ovmax = np.max(overlaps) + jmax = np.argmax(overlaps) + + if ovmax > ovthresh: + if not R['difficult'][jmax]: + if not R['det'][jmax]: + tp[d] = 1. + R['det'][jmax] = 1 + else: + fp[d] = 1. + else: + fp[d] = 1. + + # compute precision recall + fp = np.cumsum(fp) + tp = np.cumsum(tp) + rec = tp / float(npos) + # avoid divide by zero in case the first detection matches a difficult + # ground truth + prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) + ap = voc_ap(rec, prec, use_07_metric) + else: + rec = -1 + prec = -1 + ap = -1 + + return rec, prec, ap \ No newline at end of file diff --git a/lib/model/train_val.py b/lib/model/train_val.py index e303bafa..f5b1749f 100644 --- a/lib/model/train_val.py +++ b/lib/model/train_val.py @@ -323,6 +323,7 @@ def train_model(self, sess, max_iters): def get_training_roidb(imdb): """Returns a roidb (Region of Interest database) for use in training.""" + cfg.TRAIN.USE_FLIPPED = False if cfg.TRAIN.USE_FLIPPED: print('Appending horizontally-flipped training examples...') imdb.append_flipped_images() diff --git a/lib/setup.py b/lib/setup.py index 84ddc8ec..27d0bada 100644 --- a/lib/setup.py +++ b/lib/setup.py @@ -127,7 +127,7 @@ def build_extensions(self): # we're only going to use certain compiler args with nvcc and not with gcc # the implementation of this trick is in customize_compiler() below extra_compile_args={'gcc': ["-Wno-unused-function"], - 'nvcc': ['-arch=sm_52', + 'nvcc': ['-arch=sm_60', '--ptxas-options=-v', '-c', '--compiler-options', diff --git a/tools/demo.py b/tools/demo.py index 288bd0b0..3cbe44e5 100755 --- a/tools/demo.py +++ b/tools/demo.py @@ -30,6 +30,9 @@ from nets.vgg16 import vgg16 from nets.resnet_v1 import resnetv1 +import matplotlib +matplotlib.use('Agg') + CLASSES = ('__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', @@ -40,7 +43,7 @@ NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)} DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)} -def vis_detections(im, class_name, dets, thresh=0.5): +def vis_detections(im, class_name, dets, image_name, thresh=0.5): """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: @@ -48,6 +51,7 @@ def vis_detections(im, class_name, dets, thresh=0.5): im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) + fig.savefig(image_name + '.result.png') ax.imshow(im, aspect='equal') for i in inds: bbox = dets[i, :4] @@ -97,7 +101,7 @@ def demo(sess, net, image_name): cls_scores[:, np.newaxis])).astype(np.float32) keep = nms(dets, NMS_THRESH) dets = dets[keep, :] - vis_detections(im, cls, dets, thresh=CONF_THRESH) + vis_detections(im, cls, dets, image_name, thresh=CONF_THRESH) def parse_args(): """Parse input arguments.""" @@ -152,4 +156,4 @@ def parse_args(): print('Demo for data/demo/{}'.format(im_name)) demo(sess, net, im_name) - plt.show() + # plt.show() diff --git a/tools/demo_fontdataset.py b/tools/demo_fontdataset.py new file mode 100644 index 00000000..569ed416 --- /dev/null +++ b/tools/demo_fontdataset.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python + +# -------------------------------------------------------- +# Tensorflow Faster R-CNN +# Licensed under The MIT License [see LICENSE for details] +# Written by Xinlei Chen, based on code from Ross Girshick +# -------------------------------------------------------- + +""" +Demo script showing detections in sample images. + +See README.md for installation instructions before running. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import _init_paths +from model.config import cfg, cfg_from_list +from model.test import im_detect +from model.nms_wrapper import nms + +from utils.timer import Timer +import tensorflow as tf +import matplotlib.pyplot as plt +import numpy as np +import os, cv2 +import argparse + +from copy import deepcopy +import xml.etree.ElementTree as ET + +from nets.vgg16 import vgg16 +from nets.resnet_v1 import resnetv1 + +from datasets.factory import get_imdb + +from PIL import Image, ImageFont, ImageDraw + +NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)} +DATASETS= { + 'fontdataset': ('fontdataset_trainval',), + 'fontdataset_test': ('fontdataset_test',), + 'pascal_voc': ('voc_2007_trainval',), + 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',) +} + +def parse_rec(filename): + """ Parse a PASCAL VOC xml file """ + tree = ET.parse(filename) + objects = [] + for obj in tree.findall('object'): + obj_struct = dict() + obj_struct['name'] = obj.find('name').text + bbox = obj.find('bndbox') + obj_struct['bbox'] = [int(bbox.find('xmin').text), + int(bbox.find('ymin').text), + int(bbox.find('xmax').text), + int(bbox.find('ymax').text)] + objects.append(obj_struct) + return objects + +def vis_detections(pil_im, class_name, dets, thresh=0.5): + """Draw detected bounding boxes.""" + inds = np.where(dets[:, -1] >= thresh)[0] + if len(inds) == 0: + return list() + + boxes = list() + draw = ImageDraw.Draw(pil_im) + #font = ImageFont.truetype(os.path.join(cfg.DATA_DIR, 'Ubuntu.ttf'), 14) + font = ImageFont.truetype('/usr/share/fonts/truetype/nanum/NanumGothic_Coding.ttf', 14) + for i in inds: + bbox = dets[i, :4] + score = dets[i, -1] + + draw.rectangle([(int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))], + fill=None, + outline=(255, 0, 0) + ) + draw.text((int(bbox[0]) - 2, int(bbox[1]) - 15), + # '{:s} {:.3f}'.format(str(class_name.encode('utf-8')), score), + # class_name + u' ' + unicode(str(score)), + class_name + u' ' + u'{:.2f}'.format(score), + font=font, + fill=(0, 0, 255), + encoding='utf-8' + ) + boxes.append([ + int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]), score, class_name + ]) + + del draw + return boxes + + +def calc_fontsize(bbox): + size_remap = {0: 0, 10: 10, 20: 20, 30: 30, 40: 50, 50: 50, 60: 50, 80: 100, 90: 100, 100: 100, 110: 100, 120: 100} + fontsize = np.maximum(bbox[2] - bbox[0], bbox[3] - bbox[1]) + fontsize = int(fontsize / 10) * 10 + fontsize = 100 if fontsize > 100 else fontsize + return size_remap[fontsize] if fontsize in size_remap else fontsize + +def compare_founding(found_boxes, answer, ovthresh=0.5): + ''' + + :param found_boxes: + :param answer: + :return: + ''' + answer_fontsize = [calc_fontsize(ans['bbox']) for ans in answer] + + num_answer = len(answer) + num_matching = 0 + num_non_matching = 0 + + num_answer_fontsize = dict() + num_answer_char = dict() + + for fontsize, ans in zip(answer_fontsize, answer): + if ans['name'] not in num_answer_char: + num_answer_char[ans['name']] = 1 + else: + num_answer_char[ans['name']] += 1 + if fontsize not in num_answer_fontsize: + num_answer_fontsize[fontsize] = 1 + else: + num_answer_fontsize[fontsize] += 1 + + char_count = dict() + fontsize_count = dict() + char_non_matching_count = dict() + fontsize_non_matching_count = dict() + + bbox = np.array([x['bbox'] for x in answer]) + + for found in found_boxes: + ovmax = -np.inf + BBGT = bbox.astype(float) + + if BBGT.size > 0: + # compute overlaps + # intersection + ixmin = np.maximum(BBGT[:, 0], found[0]) + iymin = np.maximum(BBGT[:, 1], found[1]) + ixmax = np.minimum(BBGT[:, 2], found[2]) + iymax = np.minimum(BBGT[:, 3], found[3]) + iw = np.maximum(ixmax - ixmin + 1., 0.) + ih = np.maximum(iymax - iymin + 1., 0.) + inters = iw * ih + + # union + uni = ((found[2] - found[0] + 1.) * (found[3] - found[1] + 1.) + + (BBGT[:, 2] - BBGT[:, 0] + 1.) * + (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) + + overlaps = inters / uni + ovmax = np.max(overlaps) + jmax = np.argmax(overlaps) + + if ovmax > ovthresh: + fontsize = answer_fontsize[jmax] + label = answer[jmax]['name'] + if label == found[5]: + if label not in char_count: + char_count[label] = 1 + else: + char_count[label] += 1 + if fontsize not in fontsize_count: + fontsize_count[fontsize] = 1 + else: + fontsize_count[fontsize] += 1 + num_matching += 1 + else: + if found[5] not in char_non_matching_count: + char_non_matching_count[found[5]] = 1 + else: + char_non_matching_count[found[5]] += 1 + fontsize = calc_fontsize(found) + if fontsize not in fontsize_non_matching_count: + fontsize_non_matching_count[fontsize] = 1 + else: + fontsize_non_matching_count[fontsize] += 1 + num_non_matching += 1 + else: + if found[5] not in char_non_matching_count: + char_non_matching_count[found[5]] = 1 + else: + char_non_matching_count[found[5]] += 1 + fontsize = calc_fontsize(found) + if fontsize not in fontsize_non_matching_count: + fontsize_non_matching_count[fontsize] = 1 + else: + fontsize_non_matching_count[fontsize] += 1 + num_non_matching += 1 + + return num_matching, num_answer, num_non_matching, fontsize_count, char_count, num_answer_fontsize, num_answer_char, char_non_matching_count, fontsize_non_matching_count + + + +def demo(sess, net, image_name, imdb, testimg): + """Detect object classes in an image using pre-computed object proposals.""" + + # Load the demo image + # im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name) + im_file = os.path.join(testimg, 'images', image_name) + anno_file = os.path.join(testimg, 'annotations', image_name.split('.')[0] + '.xml') + print(im_file, anno_file) + im = cv2.imread(im_file) + + # Detect all object classes and regress object bounds + timer = Timer() + timer.tic() + scores, boxes = im_detect(sess, net, im) + timer.toc() + print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0])) + + pil_im = Image.open(im_file) + + CLASSES = imdb.classes + # Visualize detections for each class + CONF_THRESH = 0.8 + NMS_THRESH = 0.3 + + found_boxes = list() + + for cls_ind, cls in enumerate(CLASSES[1:]): + cls_ind += 1 # because we skipped background + cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] + cls_scores = scores[:, cls_ind] + dets = np.hstack((cls_boxes, + cls_scores[:, np.newaxis])).astype(np.float32) + keep = nms(dets, NMS_THRESH) + dets = dets[keep, :] + found_boxes += vis_detections(pil_im, cls, dets, thresh=CONF_THRESH) + + result_file = os.path.join(testimg, 'result', image_name.split('.')[0] + '_result.jpg') + os.makedirs(result_file) + pil_im.save(result_file) + + answer = parse_rec(anno_file) + num_matching, num_answer, num_non_matching, fontsize_count, char_count, num_answer_fontsize, num_answer_char, char_non_matching_count, fontsize_non_matching_count = compare_founding(found_boxes, answer) + return num_matching, num_answer, num_non_matching, fontsize_count, char_count, num_answer_fontsize, num_answer_char, char_non_matching_count, fontsize_non_matching_count + +def parse_args(): + """Parse input arguments.""" + parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo') + parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]', + choices=NETS.keys(), default='res101') + parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]', + choices=DATASETS.keys(), default='pascal_voc_0712') + parser.add_argument('--index', dest='index', help='Index list file name', + default=' ') + parser.add_argument('--testimg', dest='testimg', help='Testing images: foler names', + default='demo') + parser.add_argument('--model', dest='model', help='Trained model file name', + default=' ') + + parser.add_argument('--set', dest='set_cfgs', + help='set config keys', default=None, + nargs=argparse.REMAINDER) + + args = parser.parse_args() + + return args + +def merge_dict(a, b): + for k, v in b.items(): + if k in a: + a[k] += v + else: + a[k] = v + return a + +if __name__ == '__main__': + cfg.TEST.HAS_RPN = True # Use RPN for proposals + args = parse_args() + + if args.set_cfgs is not None: + cfg_from_list(args.set_cfgs) + + # model path + demonet = args.demo_net + dataset = args.dataset + tfmodel = args.model + index_file = args.index + + testimg = args.testimg + print(testimg) + + if not os.path.isfile(tfmodel + '.meta'): + raise IOError(('{:s} not found.\nDid you download the proper networks from ' + 'our server and place them properly?').format(tfmodel + '.meta')) + + # set config + tfconfig = tf.ConfigProto(allow_soft_placement=True) + tfconfig.gpu_options.allow_growth=True + + # init session + sess = tf.Session(config=tfconfig) + + # load dataset + imdb = get_imdb(dataset) + + # load network + if demonet == 'vgg16': + net = vgg16() + elif demonet == 'res101': + net = resnetv1(num_layers=101) + else: + raise NotImplementedError + net.create_architecture("TEST", imdb.num_classes, + tag='default', anchor_scales=[2,3,4,5,6,8, 16, 32]) + saver = tf.train.Saver() + saver.restore(sess, tfmodel) + + print('Loaded network {:s}'.format(tfmodel)) + + # im_names = ['000456.jpg', '000542.jpg', '001150.jpg', + # '001763.jpg', '004545.jpg'] + + #im_names = [str(x) + '.png' for x in range(50)] + # im_names = [str(x) + '.png' for x in range(8000)] + + with open(index_file, 'r') as f: + lines = f.readlines() + im_names = [x.strip() + '.jpg' for x in lines] + + num_matching_sum = 0 + num_answer_sum = 0 + num_non_matching_sum = 0 + fontsize_count_sum = dict() + char_count_sum = dict() + + fontsize_non_matching_count_sum = dict() + char_count_non_matching_sum = dict() + + num_answer_fontsize_sum = dict() + num_answer_char_sum = dict() + + for idx, im_name in enumerate(im_names): + print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') + print('Demo for {}/{}'.format(testimg, im_name)) + num_matching, num_answer, num_non_matching, fontsize_count, char_count, num_answer_fontsize, num_answer_char, char_non_matching_count, fontsize_non_matching_count = demo(sess, net, im_name, imdb, testimg) + + num_matching_sum += num_matching + num_answer_sum += num_answer + num_non_matching_sum += num_non_matching + merge_dict(fontsize_count_sum, fontsize_count) + merge_dict(char_count_sum, char_count) + + merge_dict(fontsize_non_matching_count_sum, fontsize_non_matching_count) + merge_dict(char_count_non_matching_sum, char_non_matching_count) + + merge_dict(num_answer_fontsize_sum, num_answer_fontsize) + merge_dict(num_answer_char_sum, num_answer_char) + #if idx == 3: break + + recall = num_matching_sum / float(num_answer_sum) + precision = num_matching_sum / float(num_matching_sum + num_non_matching_sum) + print(num_matching_sum, num_answer_sum, 'recall', recall, 'precision', precision) + print(sum(fontsize_count_sum.values()), sum(char_count_sum.values()), sum(num_answer_fontsize_sum.values()), sum(num_answer_char_sum.values())) + + # calculate positive + negative prediction: to calculate precision + fontsize_count_pos_neg_sum = deepcopy(fontsize_count_sum) + char_count_pos_neg_sum = deepcopy(char_count_sum) + merge_dict(fontsize_count_pos_neg_sum, fontsize_non_matching_count_sum) + merge_dict(char_count_pos_neg_sum, char_count_non_matching_sum) + + merge_dict(fontsize_non_matching_count_sum, fontsize_non_matching_count) + merge_dict(char_count_non_matching_sum, char_non_matching_count) + + print('char, recall') + for char, count in num_answer_char_sum.items(): + if char in char_count_sum: + recall = float(char_count_sum[char]) / float(count) + else: + recall = 0 + print(char.encode('utf-8').strip(), recall) + + print('char, precision') + for char, count in char_count_pos_neg_sum.items(): + if char in char_count_sum: + recall = float(char_count_sum[char]) / float(count) + else: + recall = 0 + print(char.encode('utf-8').strip(), recall) + + print('fontsize, recall') + for fontsize, count in num_answer_fontsize_sum.items(): + if fontsize in fontsize_count_sum: + recall = float(fontsize_count_sum[fontsize]) / float(count) + else: + recall = 0 + print(fontsize, recall) + + print('fontsize, precision') + for fontsize, count in fontsize_count_pos_neg_sum.items(): + if fontsize in fontsize_count_sum: + recall = float(fontsize_count_sum[fontsize]) / float(count) + else: + recall = 0 + print(fontsize, recall) + diff --git a/tools/fontdataset_synth_image.py b/tools/fontdataset_synth_image.py new file mode 100644 index 00000000..9469d25c --- /dev/null +++ b/tools/fontdataset_synth_image.py @@ -0,0 +1,431 @@ +''' +# Generate image with random string & font & size +1. select font +1. select size +1. select chars 2 ~ 10 +1. pick random position +1. check fill_image_ratio + +''' + +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image, ImageFont, ImageDraw + +from xml.etree.ElementTree import Element, SubElement, ElementTree, dump + +import random +import copy +from string import ascii_lowercase, ascii_uppercase +from os import listdir, path, makedirs +from os.path import isfile, isdir, join, basename, exists + +import os +import math +begin = 0xac00 +end = 0xd7a3 + +POINT_PER_CELL = 10 + +NUM_OF_NONE = 0 + +class Sample_IMG(): + def __init__(self, fill_ratio=0.3, width=800, height=600): + self.fill_ratio = fill_ratio + self.width = width + self.height = height + self.fill_area = 0 + + self.chars = list() + self.char_positions = list() + self.boxes = list() + + ''' + # store boxes by descending order + 1. begin with ((0, 0), (N, M)) + 2. if place box in ((i, j), (n, m)) where (i, j > 0 and n < N, m < M and i < n, j < m) + 3. split remaining areas into 4 boxes: ((0, 0), (N, j)), ((0, j), (i, m)), ((n, j), (N, m)), ((0, m), (N, M)) + ''' + self.N = int(self.width / POINT_PER_CELL) + self.M = int(self.height / POINT_PER_CELL) + self.areas = [(0, 0, self.N, self.M)] + + def hasFill(self): + if self.fill_ratio * self.width * self.height < self.fill_area: + return True + else: + return False + + def appendChars(self, chDataset, font, size, chars): + imFont = chDataset.getIMFont(font, size) + char_sizes = [imFont.getsize(ch) for ch in chars] + sum_area = sum([ch_size[0] * ch_size[1] for ch_size in char_sizes]) + + self.chars.append((font, size, chars, char_sizes)) + self.fill_area += sum_area + + def assignBox(self, char_width, char_height): + ''' + + 1. get one box from self.area + 2. check is it fit to box + + :param char_width: + :param char_height: + :return: + ''' + char_width_n = math.ceil(char_width / POINT_PER_CELL) + char_height_m = math.ceil(char_height / POINT_PER_CELL) + + assigned_box = None + + new_boxes = [] + for idx, box in enumerate(self.areas): + zero_x = box[0] + zero_y = box[1] + N = box[2] + M = box[3] + box_width = box[2] - box[0] + box_height = box[3] - box[1] + # char_box not fit into box, skip it + if box_width < char_width_n or box_height < char_height_m: + continue + + if box_width == char_width_n: + i = zero_x + else: + i = zero_x + random.randrange(box_width - char_width_n) + + if box_height == char_height_m: + j = zero_y + else: + j = zero_y + random.randrange(box_height - char_height_m) + + # x_offset = random.randrange(POINT_PER_CELL) + # y_offset = random.randrange(POINT_PER_CELL) + x_offset = 0 + y_offset = 0 + + n = i + char_width_n + m = j + char_height_m + + assigned_box = (i * POINT_PER_CELL + x_offset, j * POINT_PER_CELL + y_offset, n * POINT_PER_CELL , m * POINT_PER_CELL) + + new_boxes = [box for box in [(zero_x, zero_y, N, j), + (zero_x, j, i, m), + (n, j, N, m), + (zero_x, m, N, M)] if box[0] < box[2] and box[1] < box[3]] + break + + self.areas = self.areas[0:idx] + self.areas[idx+1:] + new_boxes + + return assigned_box + + def placeStoreChars(self): + # sort by area of chars, descending order + self.chars.sort(key=lambda x: x[1], reverse=True) + + for char_info in self.chars: + font = char_info[0] + font_size = char_info[1] + chars = char_info[2] + char_sizes = char_info[3] + char_width = sum([ch_size[0] for ch_size in char_sizes]) + char_height = max([ch_size[1] for ch_size in char_sizes]) + + box = self.assignBox(char_width, char_height) + if box is None: + x = random.randrange(self.width - char_width - 1) + y = random.randrange(self.height - char_height - 1) + box = (x, y, x + char_width, y + char_height) + + # skip invalid placement + if box[2] >= self.width or box[3] >= self.height: + continue + + self.boxes.append(box) + + + def makeXML(self, char_list, img_filename, idx=None, output_path=None): + # char_info: (x, y, font_size, font_size) + + def indent(elem, level=0): + i = "\n" + level * " " + if len(elem): + if not elem.text or not elem.text.strip(): + elem.text = i + " " + if not elem.tail or not elem.tail.strip(): + elem.tail = i + for elem in elem: + indent(elem, level + 1) + if not elem.tail or not elem.tail.strip(): + elem.tail = i + else: + if level and (not elem.tail or not elem.tail.strip()): + elem.tail = i + + annotation = Element("annotation") + SubElement(annotation, "folder").text = output_path.split('/')[-1] + SubElement(annotation, "filename").text = str(img_filename.split('/')[-1]) + size = Element("size") + annotation.append(size) + SubElement(size, "width").text = str(self.width) + SubElement(size, "height").text = str(self.height) + SubElement(size, "depth").text = '3' + + for ch_info in char_list: + obj = Element("object") + annotation.append(obj) + SubElement(obj, "name").text = ch_info[1] + bndbox = Element("bndbox") + obj.append(bndbox) + + x, y, width, height = ch_info[0] + + SubElement(bndbox, "xmin").text = str(x) + SubElement(bndbox, "ymin").text = str(y) + SubElement(bndbox, "xmax").text = str(width) + SubElement(bndbox, "ymax").text = str(height) + + indent(annotation) + ElementTree(annotation).write(os.path.join(output_path, 'annotations', '%d.xml' % idx), encoding='utf-8') + + def prn(self, chDataset=None, idx=None, output_path=None): + ''' + + :param chDataset: + :param filename: + :return: + ''' + # print('fill_area', self.fill_area, 'width', self.width, 'height', self.height) + # print('chars', [ch for ch in self.chars]) + # print('# of chars', sum([len(ch[2]) for ch in self.chars])) + + char_list = [] + + # base_font = ImageFont.truetype('/Library/Fonts/AppleGothic.ttf', 10) + if chDataset is not None: + im = Image.new('RGB', (800, 600), color=(256, 256, 256)) + draw = ImageDraw.Draw(im) + for box, char_info in zip(self.boxes, self.chars): + font = char_info[0] + font_size = char_info[1] + chars = char_info[2] + + self.boxes.append(box) + imFont = chDataset.getIMFont(font, font_size) + draw.text(box[:2], ''.join(chars), font=imFont, fill=(0, 0, 0)) + + # label_box = [box[0], box[1] - 10] + # draw.text(label_box, ''.join(chars) + font.split('/')[-1], font=base_font, fill=(256, 0, 0)) + # draw.rectangle(box, outline=(256, 0, 0)) + + width_offset = 0 + for ch in chars: + ch_width, ch_height = imFont.getsize(ch) + ch_box = (box[0] + width_offset, box[1], box[0] + width_offset + ch_width, box[1] + ch_height) + # draw.rectangle(ch_box, outline=(0, 256, 0)) + width_offset += ch_width + char_list.append((ch_box, ch)) + + if idx is None or output_path is None: + import uuid + filename = '%s.png' % uuid.uuid4() + else: + filename = join(output_path, 'images', '%d.png' % idx) + im.save(filename) + + self.makeXML(char_list, filename, idx=idx, output_path=output_path) + + return sum([len(ch[2]) for ch in self.chars]) + + @classmethod + def generate(cls, chDataset, n_char=10, fill_ratio=0.2, width=800, height=600): + ''' + generate single image sample containing multiple chars + + :param n_char: number of char per font & size + :param ratio: fill area ratio of image with character region + :param width: width of image + :param height: height of image + :return: + ''' + if chDataset.isValid() is False: + return None + + sampleImg = Sample_IMG(fill_ratio=fill_ratio, width=width, height=height) + + while sampleImg.hasFill() is False and chDataset.hasMoreData() is True: + tuple_idx, chars = chDataset.getChars(n_char=n_char) + if len(chars) > 0: + (font, size) = chDataset.font_size_tuples[tuple_idx] + sampleImg.appendChars(chDataset, font, size, chars) + + sampleImg.placeStoreChars() + return sampleImg + +class CH_Dataset(): + def __init__(self, font_path=None): + self.char_list = [chr(begin + idx) for idx in range(end - begin + 1)] \ + + [x for x in (ascii_lowercase + ascii_uppercase)] + [str(x) for x in range(10)] + [x for x in + '~!@#$%^&*()_+-=<>?,.;:[]{}|'] + + #self.char_list = self.char_list[:1000] + self.font_list = [join(font_path, f) for f in listdir(font_path) if + isfile(join(font_path, f)) and f.find('.DS_Store') == -1] + + self.font_sizes = [10] * 5 + [20] * 10 + [30] * 7 + [50] * 5 + [100] * 2 + + print('total # of chars', len(self.char_list) * len(self.font_list) * len(self.font_sizes)) + + self.counter = 0 + self.gen_counter = 0 + self.fill_ratio = 0 + + self.font_size_tuples = None + self.char_list_list = None + self.counter_list = None + self.filled_counter = None + self.filled_counter_num = 0 + + self.imFont = dict() + + def isStop(self): + self.counter = self.counter + 1 + if self.counter > 50: + return True + else: + return False + + def getFont(self): + if self.font_list is not None and len(self.font_list) > 0 and len(self.font_sizes) > 0: + self.gen_counter += 1 + return self.font_list[self.gen_counter % len(self.font_list)], self.font_sizes[ + self.gen_counter % len(self.font_sizes)] + else: + return None, None + + def isValid(self): + if self.font_size_tuples is None or self.char_list_list is None or self.counter_list is None: + return False + else: + return True + + def hasMoreData(self): + if self.filled_counter_num >= len(self.counter_list): + return False + else: + return True + + def getIMFont(self, font, size): + if font not in self.imFont: + self.imFont[font] = dict() + + if size not in self.imFont[font]: + self.imFont[font][size] = ImageFont.truetype(font=font, size=size) + return self.imFont[font][size] + + def getChars(self, n_char=10): + # pick a random char list + idx = random.randrange(len(self.char_list_list)) + + # split first k chars(1~10) + k_chars = random.randrange(n_char) + 1 + + # get fist k chars + current_char_idx = self.counter_list[idx] + + # make sure target_idx <= len(self.char_list_list[idx]) + target_idx = current_char_idx + k_chars if current_char_idx + k_chars < len(self.char_list_list[idx]) else len( + self.char_list_list[idx]) + + # pick chars (current_char_idx:target_idx) + self.counter_list[idx] = target_idx + + if target_idx >= len(self.char_list_list[idx]): + self.filled_counter[idx] = True + self.filled_counter_num = sum(self.filled_counter) + return idx, self.char_list_list[idx][current_char_idx:target_idx] + + def generateSamples(self, n_char=7, fill_ratio=0.2, width=800, height=600, output_path=None): + ''' + generate samples: + 1. generate (font, size) tuple + 2. create N char list(N = # of (font, size) tuple) + 3. shuffle each char list + 4. create + + :param n_char: number of char per sentence + :param ratio: fill area ratio of image with character region + :param width: width of image + :param height: height of image + + :return: + ''' + self.font_size_tuples = [(f, s) for f in self.font_list for s in self.font_sizes] + self.char_list_list = [copy.deepcopy(self.char_list) for _ in range(len(self.font_size_tuples))] + [random.shuffle(ch_list) for ch_list in self.char_list_list] + self.counter_list = [0] * len(self.char_list_list) + + self.filled_counter = [counter >= len(char_list) for counter, char_list in zip(self.counter_list, self.char_list_list)] + + total_chars = 0 + results = list() + + if output_path is not None: + if not os.path.exists(output_path): + os.makedirs(output_path) + if not os.path.exists(join(output_path, 'images')): + os.makedirs(join(output_path, 'images')) + if not os.path.exists(join(output_path, 'annotations')): + os.makedirs(join(output_path, 'annotations')) + + output_idx = 0 + while self.hasMoreData() is True: + gen_img = Sample_IMG.generate(self, n_char=n_char, fill_ratio=fill_ratio, width=width, height=height) + if gen_img: + total_chars += gen_img.prn(chDataset=self, idx=output_idx, output_path=output_path) + results.append(gen_img) + output_idx += 1 + + NUMBER_OF_IMAGES = output_idx + # split train/val/test set + shuffled_index = list(range(NUMBER_OF_IMAGES)) + random.shuffle(shuffled_index) + + # splitting train/validation/test set (unit: %) + TRAIN_SET = 80 + VALID_SET = 10 + TEST_SET = 10 + num_train = int(NUMBER_OF_IMAGES * TRAIN_SET / (TRAIN_SET + VALID_SET + TEST_SET)) + num_valid = int(NUMBER_OF_IMAGES * VALID_SET / (TRAIN_SET + VALID_SET + TEST_SET)) + num_test = NUMBER_OF_IMAGES - num_train - num_valid + + with open(join(output_path, 'train.txt'), "w") as wf: + for index in shuffled_index[0:num_train]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'val.txt'), "w") as wf: + for index in shuffled_index[num_train:num_train + num_valid]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'trainval.txt'), "w") as wf: + for index in shuffled_index[0:num_train + num_valid]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'test.txt'), "w") as wf: + for index in shuffled_index[num_train + num_valid:]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'labels.txt'), "w") as wf: + for label in self.char_list: + wf.write(str(label) + '\n') + + print("Train / Valid / Test : {} / {} / {}".format(num_train, num_valid, num_test)) + print("Output path: {}".format(output_path)) + + return results + + +chd = CH_Dataset(font_path='fonts') +chd.generateSamples(output_path='data/fontdataset') diff --git a/tools/fontdataset_synth_image_with_bg.py b/tools/fontdataset_synth_image_with_bg.py new file mode 100644 index 00000000..05e889b5 --- /dev/null +++ b/tools/fontdataset_synth_image_with_bg.py @@ -0,0 +1,455 @@ +# -*- coding: utf-8 -*- + +''' +# Generate image with random string & font & size +1. select font +1. select size +1. select chars 2 ~ 10 +1. pick random position +1. check fill_image_ratio + +''' + +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image, ImageFont, ImageDraw + +from xml.etree.ElementTree import Element, SubElement, ElementTree, dump + +import random +import copy +from string import ascii_lowercase, ascii_uppercase +from os import listdir, path, makedirs +from os.path import isfile, isdir, join, basename, exists +import shutil +import os +import math +begin = 0xac00 +end = 0xd7a3 + +POINT_PER_CELL = 10 + +NUM_OF_NONE = 0 + +class Sample_IMG(): + def __init__(self, fill_ratio=0.3, width=800, height=600): + self.fill_ratio = fill_ratio + self.width = width + self.height = height + self.fill_area = 0 + + self.chars = list() + self.char_positions = list() + self.boxes = list() + + ''' + # store boxes by descending order + 1. begin with ((0, 0), (N, M)) + 2. if place box in ((i, j), (n, m)) where (i, j > 0 and n < N, m < M and i < n, j < m) + 3. split remaining areas into 4 boxes: ((0, 0), (N, j)), ((0, j), (i, m)), ((n, j), (N, m)), ((0, m), (N, M)) + ''' + self.N = int(self.width / POINT_PER_CELL) + self.M = int(self.height / POINT_PER_CELL) + self.areas = [(0, 0, self.N, self.M)] + + def hasFill(self): + if self.fill_ratio * self.width * self.height < self.fill_area: + return True + else: + return False + + def appendChars(self, chDataset, font, size, chars): + imFont = chDataset.getIMFont(font, size) + char_sizes = [imFont.getsize(ch) for ch in chars] + sum_area = sum([ch_size[0] * ch_size[1] for ch_size in char_sizes]) + + self.chars.append((font, size, chars, char_sizes)) + self.fill_area += sum_area + + def assignBox(self, char_width, char_height): + ''' + + 1. get one box from self.area + 2. check is it fit to box + + :param char_width: + :param char_height: + :return: + ''' + char_width_n = math.ceil(char_width / POINT_PER_CELL) + char_height_m = math.ceil(char_height / POINT_PER_CELL) + + assigned_box = None + + new_boxes = [] + for idx, box in enumerate(self.areas): + zero_x = box[0] + zero_y = box[1] + N = box[2] + M = box[3] + box_width = box[2] - box[0] + box_height = box[3] - box[1] + # char_box not fit into box, skip it + if box_width < char_width_n or box_height < char_height_m: + continue + + if box_width == char_width_n: + i = zero_x + else: + i = zero_x + random.randrange(box_width - char_width_n) + + if box_height == char_height_m: + j = zero_y + else: + j = zero_y + random.randrange(box_height - char_height_m) + + # x_offset = random.randrange(POINT_PER_CELL) + # y_offset = random.randrange(POINT_PER_CELL) + x_offset = 0 + y_offset = 0 + + n = i + char_width_n + m = j + char_height_m + + assigned_box = (i * POINT_PER_CELL + x_offset, j * POINT_PER_CELL + y_offset, n * POINT_PER_CELL , m * POINT_PER_CELL) + + new_boxes = [box for box in [(zero_x, zero_y, N, j), + (zero_x, j, i, m), + (n, j, N, m), + (zero_x, m, N, M)] if box[0] < box[2] and box[1] < box[3]] + break + + self.areas = self.areas[0:idx] + self.areas[idx+1:] + new_boxes + + return assigned_box + + def placeStoreChars(self): + # sort by area of chars, descending order + self.chars.sort(key=lambda x: x[1], reverse=True) + + for char_info in self.chars: + font = char_info[0] + font_size = char_info[1] + chars = char_info[2] + char_sizes = char_info[3] + char_width = sum([ch_size[0] for ch_size in char_sizes]) + char_height = max([ch_size[1] for ch_size in char_sizes]) + + box = self.assignBox(char_width, char_height) + if box is None: + x = random.randrange(self.width - char_width) + y = random.randrange(self.height - char_height) + box = (x, y, x + char_width, y + char_height) + + self.boxes.append(box) + + + def makeXML(self, char_list, img_filename, idx=None, output_path=None): + # char_info: (x, y, font_size, font_size) + + def indent(elem, level=0): + i = "\n" + level * " " + if len(elem): + if not elem.text or not elem.text.strip(): + elem.text = i + " " + if not elem.tail or not elem.tail.strip(): + elem.tail = i + for elem in elem: + indent(elem, level + 1) + if not elem.tail or not elem.tail.strip(): + elem.tail = i + else: + if level and (not elem.tail or not elem.tail.strip()): + elem.tail = i + + annotation = Element("annotation") + SubElement(annotation, "folder").text = output_path.split('/')[-1] + SubElement(annotation, "filename").text = str(img_filename.split('/')[-1]) + size = Element("size") + annotation.append(size) + SubElement(size, "width").text = str(self.width) + SubElement(size, "height").text = str(self.height) + SubElement(size, "depth").text = '3' + + for ch_info in char_list: + obj = Element("object") + annotation.append(obj) + SubElement(obj, "name").text = ch_info[1] + SubElement(obj, "fontsize").text = str(ch_info[2]) + bndbox = Element("bndbox") + obj.append(bndbox) + + x, y, width, height = ch_info[0] + + SubElement(bndbox, "xmin").text = str(x) + SubElement(bndbox, "ymin").text = str(y) + SubElement(bndbox, "xmax").text = str(width) + SubElement(bndbox, "ymax").text = str(height) + + indent(annotation) + ElementTree(annotation).write(os.path.join(output_path, 'annotations', '%d.xml' % idx), encoding='utf-8') + + def prn(self, chDataset=None, idx=None, output_path=None): + ''' + + :param chDataset: + :param filename: + :return: + ''' + # print('fill_area', self.fill_area, 'width', self.width, 'height', self.height) + # print('chars', [ch for ch in self.chars]) + # print('# of chars', sum([len(ch[2]) for ch in self.chars])) + + char_list = [] + + # base_font = ImageFont.truetype('/Library/Fonts/AppleGothic.ttf', 10) + if chDataset is not None: + # im = Image.new('RGB', (800, 600), color=(256, 256, 256)) + bg_name = chDataset.bg_list[idx % len(chDataset.bg_list)] + im = Image.open(bg_name) + + draw = ImageDraw.Draw(im) + for box, char_info in zip(self.boxes, self.chars): + font = char_info[0] + font_size = char_info[1] + chars = char_info[2] + + self.boxes.append(box) + imFont = chDataset.getIMFont(font, font_size) + draw.text(box[:2], ''.join(chars), font=imFont, fill=(0, 0, 0)) + + # label_box = [box[0], box[1] - 10] + # draw.text(label_box, ''.join(chars) + font.split('/')[-1], font=base_font, fill=(256, 0, 0)) + # draw.rectangle(box, outline=(256, 0, 0)) + + width_offset = 0 + for ch in chars: + ch_width, ch_height = imFont.getsize(ch) + ch_box = (box[0] + width_offset, box[1], box[0] + width_offset + ch_width, box[1] + ch_height) + # draw.rectangle(ch_box, outline=(0, 256, 0)) + width_offset += ch_width + char_list.append((ch_box, ch, font_size)) + + if idx is None or output_path is None: + import uuid + filename = '%s.jpg' % uuid.uuid4() + else: + filename = join(output_path, 'images', '%d.jpg' % idx) + im.save(filename) + + self.makeXML(char_list, filename, idx=idx, output_path=output_path) + + return sum([len(ch[2]) for ch in self.chars]) + + @classmethod + def generate(cls, chDataset, n_char=10, fill_ratio=0.2, width=800, height=600): + ''' + generate single image sample containing multiple chars + + :param n_char: number of char per font & size + :param ratio: fill area ratio of image with character region + :param width: width of image + :param height: height of image + :return: + ''' + if chDataset.isValid() is False: + return None + + sampleImg = Sample_IMG(fill_ratio=fill_ratio, width=width, height=height) + + while sampleImg.hasFill() is False and chDataset.hasMoreData() is True: + tuple_idx, chars = chDataset.getChars(n_char=n_char) + if len(chars) > 0: + (font, size) = chDataset.font_size_tuples[tuple_idx] + sampleImg.appendChars(chDataset, font, size, chars) + + sampleImg.placeStoreChars() + return sampleImg + +class CH_Dataset(): + def __init__(self, font_path=None, bg_path=None): + + self.char_list = [chr(begin + idx) for idx in range(end - begin + 1)] \ + + [x for x in (ascii_lowercase + ascii_uppercase)] + [str(x) for x in range(10)] + [x for x in + '~!@#$%^&*()_+-=<>?,.;:[]{}|'] + self.char_list = self.char_list[:1000] + + #with open('labels.txt', 'r+t', encoding='utf-8') as rf: + # content = rf.readlines() + #self.char_list = [x.strip() for x in content] + [str(x) for x in range(10)] + + self.font_list = [join(font_path, f) for f in listdir(font_path) if + isfile(join(font_path, f)) and f.find('.DS_Store') == -1] + + self.font_sizes = [10] * 5 + [20] * 10 + [30] * 7 + [50] * 5 + [100] * 2 + # self.font_sizes = [x for x in range(20,31)] * 80 + + self.bg_list = [join(bg_path, f) for f in listdir(bg_path) if + isfile(join(bg_path, f)) and f.find('.DS_Store') == -1] + + print('# of chars: {}'.format(len(self.char_list))) + print('# of fonts: {}'.format(len(self.font_list))) + print('# of background imgs: {}'.format(len(self.bg_list))) + print('total # of chars', len(self.char_list) * len(self.font_list) * len(self.font_sizes)) + + self.counter = 0 + self.gen_counter = 0 + self.fill_ratio = 0 + + self.font_size_tuples = None + self.char_list_list = None + self.counter_list = None + self.filled_counter = None + self.filled_counter_num = 0 + + self.imFont = dict() + + def isStop(self): + self.counter = self.counter + 1 + if self.counter > 50: + return True + else: + return False + + def getFont(self): + if self.font_list is not None and len(self.font_list) > 0 and len(self.font_sizes) > 0: + self.gen_counter += 1 + return self.font_list[self.gen_counter % len(self.font_list)], self.font_sizes[ + self.gen_counter % len(self.font_sizes)] + else: + return None, None + + def isValid(self): + if self.font_size_tuples is None or self.char_list_list is None or self.counter_list is None: + return False + else: + return True + + def hasMoreData(self): + if self.filled_counter_num >= len(self.counter_list): + return False + else: + return True + + def getIMFont(self, font, size): + if font not in self.imFont: + self.imFont[font] = dict() + + if size not in self.imFont[font]: + self.imFont[font][size] = ImageFont.truetype(font=font, size=size) + return self.imFont[font][size] + + def getChars(self, n_char=10): + # pick a random char list + idx = random.randrange(len(self.char_list_list)) + + # split first k chars(1~10) + k_chars = random.randrange(n_char) + 1 + + # get fist k chars + current_char_idx = self.counter_list[idx] + + # make sure target_idx <= len(self.char_list_list[idx]) + target_idx = current_char_idx + k_chars if current_char_idx + k_chars < len(self.char_list_list[idx]) else len( + self.char_list_list[idx]) + + # pick chars (current_char_idx:target_idx) + self.counter_list[idx] = target_idx + + if target_idx >= len(self.char_list_list[idx]): + self.filled_counter[idx] = True + self.filled_counter_num = sum(self.filled_counter) + return idx, self.char_list_list[idx][current_char_idx:target_idx] + + def generateSamples(self, n_char=7, fill_ratio=0.2, width=800, height=600, output_path=None): + ''' + generate samples: + 1. generate (font, size) tuple + 2. create N char list(N = # of (font, size) tuple) + 3. shuffle each char list + 4. create + + :param n_char: number of char per sentence + :param ratio: fill area ratio of image with character region + :param width: width of image + :param height: height of image + + :return: + ''' + self.font_size_tuples = [(f, s) for f in self.font_list for s in self.font_sizes] + self.char_list_list = [copy.deepcopy(self.char_list) for _ in range(len(self.font_size_tuples))] + [random.shuffle(ch_list) for ch_list in self.char_list_list] + self.counter_list = [0] * len(self.char_list_list) + + self.filled_counter = [counter >= len(char_list) for counter, char_list in zip(self.counter_list, self.char_list_list)] + + total_chars = 0 + results = list() + + if output_path is not None: + if not os.path.exists(output_path): + os.makedirs(output_path) + if not os.path.exists(join(output_path, 'images')): + os.makedirs(join(output_path, 'images')) + if not os.path.exists(join(output_path, 'annotations')): + os.makedirs(join(output_path, 'annotations')) + + output_idx = 0 + while self.hasMoreData() is True: + bg_name = self.bg_list[output_idx % len(self.bg_list)] + im = Image.open(bg_name) + width, height = im.size + n_char = math.floor(width / 100) - 1 + + gen_img = Sample_IMG.generate(self, n_char=n_char, fill_ratio=fill_ratio, width=width, height=height) + if gen_img: + total_chars += gen_img.prn(chDataset=self, idx=output_idx, output_path=output_path) + results.append(gen_img) + output_idx += 1 + + if not (output_idx % 100): + print("Image #: {} created".format(output_idx)) + + print("Total # of images: ", output_idx) + + NUMBER_OF_IMAGES = output_idx + # split train/val/test set + shuffled_index = list(range(NUMBER_OF_IMAGES)) + random.shuffle(shuffled_index) + + # splitting train/validation/test set (unit: %) + TRAIN_SET = 98 + VALID_SET = 1 + TEST_SET = 1 + num_train = int(NUMBER_OF_IMAGES * TRAIN_SET / (TRAIN_SET + VALID_SET + TEST_SET)) + num_valid = int(NUMBER_OF_IMAGES * VALID_SET / (TRAIN_SET + VALID_SET + TEST_SET)) + num_test = NUMBER_OF_IMAGES - num_train - num_valid + + with open(join(output_path, 'train.txt'), "w") as wf: + for index in shuffled_index[0:num_train]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'val.txt'), "w") as wf: + for index in shuffled_index[num_train:num_train + num_valid]: + wf.write(str(index) + '\n') + + with open(join(output_path, "trainval.txt"), "w") as wf: + for index in shuffled_index[0:num_train + num_valid]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'test.txt'), "w") as wf: + for index in shuffled_index[num_train + num_valid:]: + wf.write(str(index) + '\n') + + with open(join(output_path, 'labels.txt'), "w") as wf: + for char in self.char_list: + wf.write(str(char) + '\n') + + print("Train / Valid / Test : {} / {} / {}".format(num_train, num_valid, num_test)) + print("Output path: {}".format(output_path)) + + return results + + +chd = CH_Dataset(font_path='fonts', bg_path='backgrounds') +chd.generateSamples(output_path='data/fontdataset')