Skip to content

Commit

Permalink
ocl: improved tuner script
Browse files Browse the repository at this point in the history
* Ensure check-pointed parameters produce correct results
* Turn last correct checkpoint into final parameter set
  • Loading branch information
hfp committed Oct 5, 2023
1 parent fb93347 commit 39583bb
Showing 1 changed file with 55 additions and 53 deletions.
108 changes: 55 additions & 53 deletions src/acc/opencl/smm/tune_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from opentuner import ConfigurationManipulator
from opentuner import MeasurementInterface
from opentuner import Result
from signal import signal, SIGINT, SIG_DFL
from signal import signal, getsignal, SIGINT, SIG_DFL
import copy
import json
import glob
Expand Down Expand Up @@ -60,9 +60,9 @@ def manipulator(self):
self.nz = self.al = self.tb = self.tc = None
self.ap = self.aa = self.ab = self.ac = None
self.xf = os.getenv("OPENCL_LIBSMM_SMM_XF")
self.gfbase = self.gfsave = self.gflops = 0
self.typename = self.typeid = None
self.device = self.size = None
self.gfbase = self.gflops = 0
self.config = None
self.exename = "acc_bench_smm"
self.exepath = os.path.join(
Expand Down Expand Up @@ -294,7 +294,7 @@ def run(self, desired_result, input, limit):
self.gflops = gflops
if 0 == self.gfbase: # seed configuration
self.gfbase = gflops
self.save_final_config(desired_result.configuration, final=False)
self.save_config(desired_result.configuration, final=False)
kernelreq = round(
(100.0 * config["BM"] * config["BN"]) / (self.mnk[0] * self.mnk[1])
)
Expand Down Expand Up @@ -441,75 +441,78 @@ def merge_jsons(self, filenames):
print("Renamed {} to {}.".format(self.args.csvfile, backup))
os.rename(self.args.csvfile, backup)

def save_final_config(self, configuration, final=True):
def save_config(self, configuration, final=True):
"""Called at termination"""
if 0 < self.gflops and configuration:
# extend result for easier reuse later
config = configuration.data
config["DEVICE"] = self.device
config["GFLOPS"] = self.gflops if not self.args.nogflops else 0
config["TYPEID"] = self.typeid
config["M"] = self.mnk[0]
config["N"] = self.mnk[1]
config["K"] = self.mnk[2]
config["S"] = self.size
filepattern = "{}-*.json".format(default_basename)
filenames = (
glob.glob(
os.path.normpath(os.path.join(self.args.jsondir, filepattern))
)
if final
else None
)
if self.handle_sigint != getsignal(SIGINT):
signal(SIGINT, SIG_DFL) # avoid recursion
if 0 >= self.gflops or not configuration:
return # nothing to save
config = configuration.data
cfgenv = self.environment(config)
result = self.run_result["returncode"] if self.run_result else 1
if 0 == result and 0 == self.args.check: # enable CHECKing result
self.run_result = self.launch(cfgenv + ["CHECK=1"])
result = self.run_result["returncode"] if self.run_result else 1
# extend result for easier reuse
config["DEVICE"] = self.device
config["GFLOPS"] = self.gflops if not self.args.nogflops else 0
config["TYPEID"] = self.typeid
config["M"] = self.mnk[0]
config["N"] = self.mnk[1]
config["K"] = self.mnk[2]
config["S"] = self.size
filepattern = "{}-*.json".format(default_basename)
filenames = (
glob.glob(os.path.normpath(os.path.join(self.args.jsondir, filepattern)))
if final
else None
)
filedot = os.path.join(self.args.jsondir, ".{}.json".format(self.args.label))
# check return code (consider not saving parameters)
if 0 == result:
self.gfsave = self.gflops
# self.manipulator().save_to_file(config, filename)
with open(
os.path.join(self.args.jsondir, ".{}.json".format(self.args.label)),
"w",
) as file:
with open(filedot, "w") as file:
cfg = config
if "XF" in config and 0 == config["XF"]:
cfg = copy.deepcopy(config)
del cfg["XF"]
json.dump(cfg, file, sort_keys=True)
file.write("\n") # append newline at EOF
if final:
if not filenames and glob.glob(self.args.csvfile):
print(
"WARNING: no JSON file found but {} exists.".format(
self.args.csvfile
)
)
filename = os.path.normpath(
os.path.join(
self.args.jsondir,
"{}-{}gflops.json".format(self.args.label, round(self.gflops)),
elif not final: # incorrect result
failed = " ".join(map(str, cfgenv)).replace("OPENCL_LIBSMM_SMM_", "")
print("FAILED: {}".format(failed))
return
if final:
if not filenames and glob.glob(self.args.csvfile):
print(
"WARNING: no JSON file found but {} exists.".format(
self.args.csvfile
)
)
os.rename(
os.path.join(self.args.jsondir, ".{}.json".format(self.args.label)),
filename,
filename = os.path.normpath(
os.path.join(
self.args.jsondir,
"{}-{}gflops.json".format(self.args.label, round(self.gfsave)),
)
)
if os.path.exists(filedot):
os.rename(filedot, filename)
if filename not in filenames:
filenames.append(filename)
self.merge_jsons(filenames)
speedup = round(
(self.gflops / self.gfbase) if 0 < self.gfbase else 0, 1
(self.gfsave / self.gfbase) if 0 < self.gfbase else 0, 1
)
print(
"Result{} was written to {}".format(
" ({}x over seed)".format(speedup) if 1 < speedup else "",
filename,
)
)
if ( # avoid recursion (self.handle_sigint != getsignal(SIGINT))
self.run_result and 0 == self.run_result["returncode"]
) and 0 == self.args.check:
signal(SIGINT, SIG_DFL)
self.run_result = self.launch(
self.environment(config) + ["CHECK=1"]
)
if self.run_result and 0 != self.run_result["returncode"]:
print("WARNING: tuned result seems to be incorrect!")
else:
print("WARNING: tuned result seems to be incorrect!")
exit(0)

def handle_sigint(self, signum, frame):
"""Handle SIGINT or CTRL-C"""
Expand All @@ -518,7 +521,7 @@ def handle_sigint(self, signum, frame):
self.mnk[0], self.mnk[1], self.mnk[2]
)
)
self.save_final_config(self.config)
self.save_config(self.config)
exit(1)


Expand Down Expand Up @@ -788,8 +791,7 @@ def handle_sigint(self, signum, frame):
os.environ["OPENCL_LIBSMM_SMM_LU"] = "{}".format(args.lu)
if 0 == args.mb:
args.mb = 64
# additional/depending arguments
try:
SmmTuner.main(args)
except: # noqa: E722
except Exception:
pass

0 comments on commit 39583bb

Please sign in to comment.