#!/usr/bin/env escript
%% -*- mode: erlang;erlang-indent-level: 4;indent-tabs-mode: nil -*-
%% ex: ft=erlang ts=4 sw=4 et
%% -------------------------------------------------------------------
%%
%% nodetool: Helper Script for interacting with live nodes
%%
%% -------------------------------------------------------------------
-mode(compile).

%% Script entry point.
%%
%% Processing order:
%%   1. On non-Windows hosts, either start epmd (when the `start_epmd'
%%      init argument is "true") or load the in-memory `nodetool_epmd'
%%      module.
%%   2. Run the purely local commands (`mnesia_dir', `mergeconf',
%%      `chkconfig'); each of these halts the VM when present.
%%   3. Consume `-name' / `-sname' / `-setcookie', starting distribution
%%      under a maintenance node name and remembering the target node in
%%      the process dictionary.
%%   4. Verify the target node is reachable, then dispatch the remaining
%%      arguments as a command (getpid, ping, stop, restart, reboot, rpc,
%%      rpc_infinity, rpcterms, eval). Anything else prints the usage text.
%%
%% Failure paths call halt(1); otherwise distribution is stopped on exit.
main(Args) ->
    case os:type() of
        {win32, nt} -> ok;
        _nix ->
            case init:get_argument(start_epmd) of
                {ok, [["true"]]} ->
                    ok = start_epmd();
                _ ->
                    {module, nodetool_epmd} = compile_epmd_module()
            end
    end,
    ok = do_with_halt(Args, "mnesia_dir", fun create_mnesia_dir/2),
    ok = do_with_halt(Args, "mergeconf", fun mergeconf/3),
    %% Both `chkconfig -config File' and `chkconfig File' are accepted.
    %% (Two fixed-arity do_with_halt handlers cannot express this: the
    %% arity-2 one always fires first and rejects the short form.)
    ok = maybe_chkconfig(Args),
    Args1 = do_with_ret(Args, "-name",
                        fun(TargetName) ->
                                ThisNode = append_node_suffix(TargetName, "_maint_"),
                                {ok, _} = net_kernel:start([ThisNode, longnames]),
                                put(target_node, nodename(TargetName))
                        end),
    Args2 = do_with_ret(Args1, "-sname",
                        fun(TargetName) ->
                                ThisNode = append_node_suffix(TargetName, "_maint_"),
                                {ok, _} = net_kernel:start([ThisNode, shortnames]),
                                put(target_node, nodename(TargetName))
                        end),
    RestArgs = do_with_ret(Args2, "-setcookie",
                           fun(Cookie) ->
                                   erlang:set_cookie(node(), list_to_atom(Cookie))
                           end),

    %% Best effort: start results are intentionally not checked.
    [application:start(App) || App <- [crypto, public_key, ssl]],
    %% 'undefined' when neither -name nor -sname was supplied.
    TargetNode = get(target_node),

    %% See if the node is currently running  -- if it's not, we'll bail
    case {net_kernel:hidden_connect_node(TargetNode), net_adm:ping(TargetNode)} of
        {true, pong} ->
            ok;
        {false, pong} ->
            io:format(standard_error, "Failed to connect to node ~p .\n", [TargetNode]),
            halt(1);
        {_, pang} ->
            io:format(standard_error, "Node ~p not responding to pings.\n", [TargetNode]),
            halt(1)
    end,

    case RestArgs of
        ["getpid"] ->
            io:format("~p\n", [list_to_integer(rpc:call(TargetNode, os, getpid, []))]);
        ["ping"] ->
            %% If we got this far, the node already responded to a ping, so just dump
            %% a "pong"
            io:format("pong\n");
        ["stop"] ->
            rpc:call(TargetNode, emqx_plugins, unload, [], 60000),
            io:format("~p\n", [rpc:call(TargetNode, init, stop, [], 60000)]);
        ["restart", "-config", ConfigFile | _RestArgs1] ->
            io:format("~p\n", [rpc:call(TargetNode, emqx, restart, [ConfigFile], 60000)]);
        ["reboot", "-config", ConfigFile | _RestArgs1] ->
            io:format("~p\n", [rpc:call(TargetNode, emqx, restart, [ConfigFile], 60000)]);
        ["rpc", Module, Function | RpcArgs] ->
            %% The remote function receives all remaining arguments as a
            %% single list-of-strings parameter.
            case rpc:call(TargetNode, list_to_atom(Module), list_to_atom(Function),
                          [RpcArgs], 60000) of
                ok ->
                    ok;
                {error, cmd_not_found} ->
                    halt(1);
                {error, Reason} ->
                    io:format("RPC to ~s error: ~p\n", [TargetNode, Reason]),
                    halt(1);
                {badrpc, Reason} ->
                    io:format("RPC to ~s failed: ~p\n", [TargetNode, Reason]),
                    halt(1);
                _ ->
                    halt(1)
            end;
        ["rpc_infinity", Module, Function | RpcArgs] ->
            %% Same as "rpc" but without a call timeout.
            case rpc:call(TargetNode, list_to_atom(Module), list_to_atom(Function), [RpcArgs], infinity) of
                ok ->
                    ok;
                {badrpc, Reason} ->
                    io:format("RPC to ~p failed: ~p\n", [TargetNode, Reason]),
                    halt(1);
                _ ->
                    halt(1)
            end;
        ["rpcterms", Module, Function | ArgsAsString] ->
            %% The remaining arguments are parsed as Erlang terms and used
            %% as the remote call's argument list.
            case rpc:call(TargetNode, list_to_atom(Module), list_to_atom(Function),
                          consult(lists:flatten(ArgsAsString)), 60000) of
                {badrpc, Reason} ->
                    io:format("RPC to ~p failed: ~p\n", [TargetNode, Reason]),
                    halt(1);
                Other ->
                    io:format("~p\n", [Other])
            end;
        ["eval" | ListOfArgs] ->
            % shells may process args into more than one, and end up stripping
            % spaces, so this converts all of that to a single string to parse
            String = binary_to_list(
                      list_to_binary(
                        join(ListOfArgs," ")
                      )
                    ),

            % then just as a convenience to users, if they forgot a trailing
            % '.' add it for them.
            Normalized =
              case lists:reverse(String) of
                [$. | _] -> String;
                R -> lists:reverse([$. | R])
              end,

            % then scan and parse the string
            {ok, Scanned, _} = erl_scan:string(Normalized),
            {ok, Parsed } = erl_parse:parse_exprs(Scanned),

            % and evaluate it on the remote node
            case rpc:call(TargetNode, erl_eval, exprs, [Parsed, [] ]) of
                {value, Value, _} ->
                    io:format ("~p\n",[Value]);
                {badrpc, Reason} ->
                    io:format("RPC to ~p failed: ~p\n", [TargetNode, Reason]),
                    halt(1)
            end;
        Other ->
            io:format("Other: ~p\n", [Other]),
            io:format("Usage: nodetool {genconfig, chkconfig|getpid|ping|stop|restart|reboot|rpc|rpc_infinity|rpcterms|eval [Terms]} [RPC]\n")
    end,
    net_kernel:stop().

%% Run chkconfig/1 when the "chkconfig" command appears in Args.
%%
%% Accepts `chkconfig -config File' as well as `chkconfig File'.
%% chkconfig/1 halts the VM, so this returns `ok' only when the command is
%% absent. A "chkconfig" without a file name raises
%% {not_enough_args_for, "chkconfig"}.
maybe_chkconfig(Args) ->
    case lists:dropwhile(fun(Arg) -> Arg =/= "chkconfig" end, Args) of
        [] ->
            ok;
        ["chkconfig", "-config", File | _] ->
            chkconfig(File);
        ["chkconfig", File | _] when File =/= "-config" ->
            chkconfig(File);
        _ ->
            error({not_enough_args_for, "chkconfig"})
    end.

%% Look for option Name in Args. When it is present, call Handler with as
%% many of the following arguments as Handler's arity and return the
%% argument list with the option and its values removed. When it is absent,
%% return Args untouched. The handler's return value is discarded.
do_with_ret(Args, Name, Handler) ->
    {arity, HandlerArity} = erlang:fun_info(Handler, arity),
    Taken = take_args(Args, Name, HandlerArity),
    case Taken of
        {OptValues, Remaining} ->
            _ = apply(Handler, OptValues),
            Remaining;
        false ->
            Args
    end.

%% Look for option Name in Args and, when present, run Handler with as many
%% of the following arguments as Handler's arity. The handler is expected to
%% halt the VM itself; if it returns, that is reported on standard error and
%% the VM is halted with status 1.
%%
%% Returns `ok' only when the option is absent.
do_with_halt(Args, Name, Handler) ->
    {arity, Arity} = erlang:fun_info(Handler, arity),
    case take_args(Args, Name, Arity) of
        false ->
            ok;
        {Args1, _Rest} ->
            erlang:apply(Handler, Args1), %% should halt
            io:format(standard_error, "~s handler did not halt~n", [Name]),
            %% Use a fixed failure status like the rest of the script. The
            %% source line number (?LINE) would change with every edit above
            %% this function and is truncated to 8 bits on POSIX systems.
            halt(1)
    end.

%% Find option OptName in Args and take the OptArity arguments following it.
%%
%% Returns `{Values, Remaining}' where Values are the OptArity arguments
%% after the first occurrence of OptName and Remaining is Args without the
%% option and its values, or `false' when OptName does not occur.
%% Raises {not_enough_args_for, OptName} when fewer than OptArity arguments
%% follow the option.
%%
%% An arity of 0 (a bare flag) yields `{[], Remaining}', the same shape as
%% every other arity, so callers can match a single result pattern.
take_args(Args, OptName, OptArity) ->
    take_args(Args, OptName, OptArity, _Scanned = []).

%% Scanned accumulates, in reverse, the arguments seen before the option.
take_args([], _, _, _) -> false; %% no such option
take_args([Name | Rest], Name, Arity, Scanned) ->
    length(Rest) >= Arity orelse error({not_enough_args_for, Name}),
    {Result, Tail} = lists:split(Arity, Rest),
    {Result, lists:reverse(Scanned) ++ Tail};
take_args([Other | Rest], Name, Arity, Scanned) ->
    take_args(Rest, Name, Arity, [Other | Scanned]).

%% Launch epmd as a daemon through the shell, quoting its path.
%% The command is expected to print nothing; any output fails the match.
start_epmd() ->
    Command = lists:append(["\"", epmd_path(), "\" -daemon"]),
    [] = os:cmd(Command),
    ok.

%% Locate the epmd executable: first in the directory containing this
%% script, then on the regular executable search path. Prints a message and
%% halts with status 1 when it cannot be found anywhere.
epmd_path() ->
    Exe = "epmd",
    ScriptDir = filename:dirname(escript:script_name()),
    case os:find_executable(Exe, ScriptDir) of
        false -> find_epmd_on_path(Exe);
        LocalEpmd -> LocalEpmd
    end.

%% Fallback lookup on the default search path; halts when epmd is missing.
find_epmd_on_path(Exe) ->
    case os:find_executable(Exe) of
        false ->
            io:format("Could not find epmd.~n"),
            halt(1);
        FoundEpmd ->
            FoundEpmd
    end.

%% Build and load the `nodetool_epmd' module from the source forms below.
%%
%% Each list element is the text of one form handed to compile/1. The result
%% is whatever compile/1 returns; main/1 expects `{module, nodetool_epmd}'.
%%
%% What the generated module does:
%%   * start_link/0 returns `ignore'.
%%   * register_node/2,3 return `{ok, Creation}' with a random Creation
%%     in 1..3.
%%   * port_please/2,3 return `{port, Port, 5}' where Port comes from
%%     dist_port/1.
%%   * names/0,1 always end in `{error, address}'.
%%   * dist_port/1 drops the "@host" part of the node name and adds the
%%     trailing digits of the remainder (0 when there are none) to 4370.
%%
%% NOTE(review): the exported function set looks like a stand-in for the
%% epmd client callback module - confirm how the VM is told to use it.
compile_epmd_module() ->
    compile([
      "-module(nodetool_epmd).",
      "-export([start_link/0]).",
      "-export([port_please/2, port_please/3]).",
      "-export([names/0, names/1]).",
      "-export([register_node/2, register_node/3]).",
      "start_link() -> ignore.",
      "register_node(_Name, _Port) ->
         Creation = rand:uniform(3),
         {ok, Creation}.",
      "register_node(Name, Port, _Family) ->
         register_node(Name, Port).",
      "port_please(Name, _IP) ->
         Port = dist_port(Name),
         Version = 5,
         {port, Port, Version}.",
      "port_please(Name, IP, _Timeout) ->
         port_please(Name, IP).",
      "names() ->
         {ok, H} = inet:gethostname(),
         names(H).",
      "names(_Hostname) ->
         {error, address}.",
      "dist_port(Name) when is_atom(Name) ->
         dist_port(atom_to_list(Name));
       dist_port(Name) when is_list(Name) ->
         NodeName = re:replace(Name, \"@.*$\", \"\"),
         Offset =
           case re:run(NodeName, \"[0-9]+$\", [{capture, first, list}]) of
             nomatch -> 0;
             {match, [OffsetAsString]} ->
               list_to_integer(OffsetAsString)
             end,
         4370 + Offset."
    ]).

%% Compile a module given as a list of form source strings (one complete,
%% dot-terminated form per element) and load it into the running VM.
%% Returns the result of code:load_binary/3.
compile(Exprs) ->
    ToForm =
        fun(Source) ->
            {ok, Tokens, _EndLocation} = erl_scan:string(lists:flatten(Source)),
            {ok, Form} = erl_parse:parse_form(Tokens),
            Form
        end,
    Forms = [ToForm(Source) || Source <- Exprs],
    {ok, Mod, Code} = compile:forms(Forms, []),
    code:load_binary(Mod, "nofile", Code).

%% Turn a node name string into a node atom. A name that already carries a
%% host part ("node@host") is converted as is; a bare name gets the host
%% part of the local node appended.
nodename(Name) ->
    SplitOpts = [{return, list}, unicode],
    case re:split(Name, "@", SplitOpts) of
        [Short] ->
            [_, LocalHost] = re:split(atom_to_list(node()), "@", SplitOpts),
            list_to_atom(Short ++ "@" ++ LocalHost);
        [_, _] ->
            list_to_atom(Name)
    end.

%% Derive a node name atom from Name by inserting Suffix followed by this
%% OS process id after the node part; any "@host" part is kept unchanged.
append_node_suffix(Name, Suffix) ->
    OsPid = os:getpid(),
    Pieces =
        case re:split(Name, "@", [{return, list}, unicode]) of
            [Node, Host] -> [Node, Suffix, OsPid, "@", Host];
            [Node] -> [Node, Suffix, OsPid]
        end,
    list_to_atom(lists:concat(Pieces)).

%% For windows???
%% Create the directory DataDir/NodeName (the outcome of the creation is not
%% checked), print its path to standard output without a trailing newline,
%% and halt with status 0.
create_mnesia_dir(DataDir, NodeName) ->
    Dir = filename:join([DataDir, NodeName]),
    _ = file:make_dir(Dir),
    io:format("~s", [Dir]),
    halt(0).

%% Concatenate EmqCfg and every "*.conf" file found in PluginsEtc into a
%% freshly named file (see appconf/0) inside DestDir, each followed by a
%% newline. On success the path of the merged file is printed to standard
%% output and the VM halts with status 0; when the destination cannot be
%% opened an error is printed to standard error and the VM halts with 1.
%% A failing read or write crashes the script instead of reporting a
%% truncated file as a success.
mergeconf(EmqCfg, PluginsEtc, DestDir) ->
    AppConf = filename:join(DestDir, appconf()),
    %% filelib:ensure_dir/1 creates the parent directories of the *file*
    %% name it is given, so it must receive the target file, not DestDir,
    %% for DestDir itself to be created (support brew on macosx).
    _ = filelib:ensure_dir(AppConf),
    case file:open(AppConf, [write]) of
        {ok, DestFd} ->
            CfgFiles = [EmqCfg | cfgfiles(PluginsEtc)],
            lists:foreach(fun(CfgFile) ->
                            {ok, CfgData} = file:read_file(CfgFile),
                            ok = file:write(DestFd, [CfgData, <<"\n">>])
                          end, CfgFiles),
            ok = file:close(DestFd),
            io:format("~s", [AppConf]),
            halt(0);
        {error, Error} ->
            %% Pass the path as an argument: used as the format string, a
            %% "~" in the path would be taken for a control sequence.
            io:format(standard_error, "Error opening file: ~ts ~ts~n",
                      [AppConf, file:format_error(Error)]),
            halt(1)
    end.

%% Name for a merged configuration file, stamped with the current local
%% date and time: "app.YYYY.MM.DD.HH.MM.SS.conf". Returns a flat string.
appconf() ->
    {Date, Time} = calendar:local_time(),
    Fields = tuple_to_list(Date) ++ tuple_to_list(Time),
    Stamped = io_lib:format(
                "app.~4..0w.~2..0w.~2..0w.~2..0w.~2..0w.~2..0w.conf", Fields),
    lists:flatten(Stamped).

%% Paths (each joined onto Dir) of all "*.conf" files directly inside Dir.
cfgfiles(Dir) ->
    ConfNames = filelib:wildcard("*.conf", Dir),
    lists:map(fun(ConfName) -> filename:join(Dir, ConfName) end, ConfNames).

%% Check that File can be read with file:consult/1 and passes validate/1.
%% Never returns: halts with status 0 when the file is fine or only has
%% non-error problems, and with status 1 when it cannot be read or parsed
%% or when validation reported at least one `{error, _}' problem. Problems
%% and read errors are written to standard error.
chkconfig(File) ->
    case file:consult(File) of
        {ok, Terms} ->
            case validate(Terms) of
                ok ->
                    halt(0);
                {error, Problems} ->
                    lists:foreach(fun print_issue/1, Problems),
                    %% halt(1) if any problems were errors
                    halt(case [x || {error, _} <- Problems] of
                             [] -> 0;
                             _  -> 1
                         end)
            end;
        {error, {Line, Mod, Term}} ->
            %% The text is passed as an argument, never as the format string:
            %% it may contain "~", which io:format would otherwise try to
            %% interpret as a control sequence.
            io:format(standard_error, "Error on line ~ts~n",
                      [file:format_error({Line, Mod, Term})]),
            halt(1);
        {error, Error} ->
            %% Same reasoning for the file name (e.g. "C:/PROGRA~1/...").
            io:format(standard_error, "Error reading config file: ~ts ~ts~n",
                      [File, file:format_error(Error)]),
            halt(1)
    end.

%%
%% Parse a string or binary holding dot-terminated Erlang terms into the
%% list of those terms, in the spirit of file:consult/1.
%% Returns {error, Info} when the text cannot be tokenised.
%%
consult(Bin) when is_binary(Bin) ->
    consult([], binary_to_list(Bin), []);
consult(Str) when is_list(Str) ->
    consult([], Str, []).

%% Cont is the scanner continuation, Str the text still to be scanned (or
%% `eof'), Acc the terms parsed so far in reverse order.
consult(Cont, Str, Acc) ->
    case erl_scan:tokens(Cont, Str, 0) of
        {more, NextCont} ->
            %% The scanner wants more input and there is none: signal the end.
            consult(NextCont, eof, Acc);
        {done, {ok, Tokens, _}, Remaining} ->
            {ok, Term} = erl_parse:parse_term(Tokens),
            consult([], Remaining, [Term | Acc]);
        {done, {eof, _Other}, _Remaining} ->
            lists:reverse(Acc);
        {done, {error, Info, _}, _Remaining} ->
            {error, Info}
    end.

%%
%% Validation functions for checking the app.config
%%
%% Run every function from get_validation_funs/0 on the single term read
%% from the config file. Returns `ok' when all of them return `true',
%% otherwise `{error, Failures}' holding every other result in order.
validate([Terms]) ->
    Outcomes = lists:map(fun(Check) -> Check(Terms) end, get_validation_funs()),
    case lists:filter(fun(Outcome) -> Outcome =/= true end, Outcomes) of
        [] ->
            ok;
        Failures ->
            {error, Failures}
    end.

%% The checks applied to the app.config term by validate/1.
%% No checks are defined at present.
get_validation_funs() ->
    [].

%% Write one validation problem, tagged `warning' or `error', to standard
%% error.
print_issue({Severity, Text}) when Severity =:= warning; Severity =:= error ->
    Format =
        case Severity of
            warning -> "Warning in app.config: ~s~n";
            error -> "Error in app.config: ~s~n"
        end,
    io:format(standard_error, Format, [Text]).

%% Local equivalent of string:join/2: concatenate the strings of a list,
%% placing Sep between consecutive elements. Kept in this file because
%% string:join/2 is being phased out while its replacement, lists:join/2,
%% only exists from OTP 19.0 on; it is hoped to be good enough for most
%% unicode input.
join([], Sep) when is_list(Sep) ->
    [];
join([First | Others], Sep) ->
    First ++ lists:flatmap(fun(Item) -> Sep ++ Item end, Others).
