This commit is contained in:
nttstar
2018-04-27 14:53:58 +08:00
parent bd45b710c1
commit f7c83cac4f
3 changed files with 27 additions and 12 deletions

View File

@@ -24,6 +24,10 @@ def main(args):
stat = [0,0]
for ds in ['ms1m', 'megaage', 'imdb']:
for n in ['train', 'val']:
if ds=='ms1m' or ds=='imdb':
continue
if n=='val' and ds!='megaage':
continue
repeat = 1
if n=='train' and ds=='megaage':
repeat = 1
@@ -34,7 +38,7 @@ def main(args):
if n=='val':
writer = val_writer
widx = val_widx
path = os.path.join(ds, '%s.rec'%n)
path = os.path.join(args.input, ds, '%s.rec'%n)
if not os.path.exists(path):
continue
imgrec = mx.recordio.MXIndexedRecordIO(path[:-3]+'idx', path, 'r') # pylint: disable=redefined-variable-type
@@ -93,6 +97,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='do dataset merge')
# general
parser.add_argument('--input', default='', type=str, help='')
parser.add_argument('--output', default='', type=str, help='')
args = parser.parse_args()
main(args)