#!/usr/bin/perl 

use POSIX qw /floor/;

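# Usage: gen_sample FRAC ROUNDS
# Reads the file "labeled" and writes "train" and "test_1" .. "test_ROUNDS".
# FRAC is (roughly) the fraction of labeled lines kept for training;
# 1-FRAC is the fraction of labeled lines used for validation.
# ROUNDS is the number of validation batches (capped at the validation size).
# => val_size = ROUNDS * floor(floor((1-FRAC)*#labeled)/ROUNDS)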

use strict;
use warnings;

# Split the lines of the file "labeled" into one training file and ROUNDS
# equally sized validation batches.
#
# Arguments (from @ARGV):
#   FRAC   - number in [0, 1]; 1-FRAC is the fraction of lines used for
#            validation.
#   ROUNDS - non-negative integer; the number of validation batches. It is
#            reduced to the validation size when it exceeds it.
#
# Output files (in the current directory):
#   train            - the leading lines not assigned to any batch
#   test_1 .. test_N - consecutive batches of batch_size lines each, taken
#                      from the tail of "labeled"
#
# The number of batches actually written is printed on STDOUT.

# Write the given lines to $path, dying on any I/O error.
sub write_lines {
    my ($path, @lines) = @_;

    open my $out, '>', $path
        or die "Cannot open '$path' for writing: $!\n";
    print {$out} @lines
        or die "Cannot write to '$path': $!\n";
    # Buffered write errors only surface at close, so check it too.
    close $out
        or die "Cannot close '$path': $!\n";
    return;
}

die "Usage: gen_sample FRAC ROUNDS\n" unless @ARGV >= 2;
my ($frac, $rounds) = @ARGV;

die "FRAC must be a number between 0 and 1, got '$frac'\n"
    unless $frac =~ /\A\+?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?\z/
        && $frac <= 1;
die "ROUNDS must be a non-negative integer, got '$rounds'\n"
    unless $rounds =~ /\A\d+\z/;

open my $in, '<', 'labeled'
    or die "Cannot open 'labeled' for reading: $!\n";
my @labeled = <$in>;
close $in
    or die "Cannot close 'labeled': $!\n";

my $total = scalar @labeled;

# The small epsilon compensates for binary floating-point error: without it
# (1.0 - 0.9) * 10 evaluates to 0.9999999999999998 and floors to 0, not 1.
my $val_size = floor((1.0 - $frac) * $total + 1e-9);
if ($rounds > $val_size) { $rounds = $val_size; }

# With no validation lines (or ROUNDS == 0) there are no batches at all;
# guard the division so the script does not die with "division by zero".
my $batch_size = $rounds > 0 ? floor($val_size / $rounds) : 0;

# Drop the remainder so that every batch has exactly $batch_size lines;
# the leftover lines stay in the training set.
$val_size = $rounds * $batch_size;

my $train_count = $total - $val_size;
write_lines('train', @labeled[0 .. $train_count - 1]);

for my $i (0 .. $rounds - 1) {
    my $beg = $train_count + $i * $batch_size;
    my $end = $beg + $batch_size - 1;
    write_lines('test_' . ($i + 1), @labeled[$beg .. $end]);
}

# Printed without a trailing newline, exactly as before, so callers that
# capture the output see the same bytes.
print $rounds;
