/**
 * Copyright 2014 Nikita Koksharov, Nickolay Borbit
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.redisson.client.handler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.redisson.client.RedisAskException;
import org.redisson.client.RedisException;
import org.redisson.client.RedisLoadingException;
import org.redisson.client.RedisMovedException;
import org.redisson.client.RedisOutOfMemoryException;
import org.redisson.client.RedisPubSubConnection;
import org.redisson.client.RedisTimeoutException;
import org.redisson.client.codec.StringCodec;
import org.redisson.client.protocol.CommandData;
import org.redisson.client.protocol.CommandsData;
import org.redisson.client.protocol.Decoder;
import org.redisson.client.protocol.QueueCommand;
import org.redisson.client.protocol.RedisCommand.ValueType;
import org.redisson.client.protocol.decoder.MultiDecoder;
import org.redisson.client.protocol.pubsub.Message;
import org.redisson.client.protocol.pubsub.PubSubMessage;
import org.redisson.client.protocol.pubsub.PubSubPatternMessage;
import org.redisson.client.protocol.pubsub.PubSubStatusMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ReplayingDecoder;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;

/**
 * Redis protocol command decoder
 *
 * Code parts from Sam Pullara
 *
 * @author Nikita Koksharov
 *
 */
public class CommandDecoder extends ReplayingDecoder<State> {

    private final Logger log = LoggerFactory.getLogger(getClass());

    public static final char CR = '\r';
    public static final char LF = '\n';
    private static final char ZERO = '0';

    // It is not needed to use concurrent map because responses are coming consecutive
    private final Map<String, MultiDecoder<Object>> pubSubMessageDecoders = new HashMap<String, MultiDecoder<Object>>();
    private final Map<String, CommandData<Object, Object>> pubSubChannels = PlatformDependent.newConcurrentHashMap();

    public void addPubSubCommand(String channel, CommandData<Object, Object> data) {
        pubSubChannels.put(channel, data);
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        QueueCommand data = ctx.channel().attr(CommandsQueue.CURRENT_COMMAND).get();

        Decoder<Object> currentDecoder = null;
        if (data == null) {
            currentDecoder = StringCodec.INSTANCE.getValueDecoder();
        }

        if (state() == null) {
            state(new State());

            if (log.isTraceEnabled()) {
                log.trace("channel: {} message: {}", ctx.channel(), in.toString(0, in.writerIndex(), CharsetUtil.UTF_8));
            }
        }
        state().setDecoderState(null);

        if (data == null) {
            decode(in, null, null, ctx.channel(), currentDecoder);
        } else if (data instanceof CommandData) {
            CommandData<Object, Object> cmd = (CommandData<Object, Object>)data;
            try {
//                if (state().getSize() > 0) {
//                    List<Object> respParts = new ArrayList<Object>();
//                    if (state().getRespParts() != null) {
//                        respParts = state().getRespParts();
//                    }
//                    decodeMulti(in, cmd, null, ctx.channel(), currentDecoder, state().getSize(), respParts, true);
//                } else {
                    decode(in, cmd, null, ctx.channel(), currentDecoder);
//                }
            } catch (IOException e) {
                cmd.getPromise().tryFailure(e);
            }
        } else if (data instanceof CommandsData) {
            CommandsData commands = (CommandsData)data;

            handleCommandsDataResponse(ctx, in, data, currentDecoder, commands);
            return;
        }

        ctx.pipeline().get(CommandsQueue.class).sendNextCommand(ctx.channel());

        state(null);
    }

    private void handleCommandsDataResponse(ChannelHandlerContext ctx, ByteBuf in, QueueCommand data,
            Decoder<Object> currentDecoder, CommandsData commands) {
        int i = state().getIndex();

        RedisException error = null;
        while (in.writerIndex() > in.readerIndex()) {
            CommandData<Object, Object> cmd = null;
            try {
                checkpoint();
                state().setIndex(i);
                cmd = (CommandData<Object, Object>) commands.getCommands().get(i);
                decode(in, cmd, null, ctx.channel(), currentDecoder);
                i++;
            } catch (IOException e) {
                cmd.getPromise().tryFailure(e);
            }
            if (!cmd.getPromise().isSuccess()) {
                if (!(cmd.getPromise().cause() instanceof RedisMovedException 
                        || cmd.getPromise().cause() instanceof RedisAskException
                            || cmd.getPromise().cause() instanceof RedisLoadingException)) {
                    error = (RedisException) cmd.getPromise().cause();
                }
            }
        }

        if (i == commands.getCommands().size()) {
            Promise<Void> promise = commands.getPromise();
            if (error != null) {
                if (!promise.tryFailure(error) && promise.cause() instanceof RedisTimeoutException) {
                    log.warn("response has been skipped due to timeout! channel: {}, command: {}", ctx.channel(), data);
                }
            } else {
                if (!promise.trySuccess(null) && promise.cause() instanceof RedisTimeoutException) {
                    log.warn("response has been skipped due to timeout! channel: {}, command: {}", ctx.channel(), data);
                }
            }

            ctx.pipeline().get(CommandsQueue.class).sendNextCommand(ctx.channel());

            state(null);
        } else {
            checkpoint();
            state().setIndex(i);
        }
    }

    private void decode(ByteBuf in, CommandData<Object, Object> data, List<Object> parts, Channel channel, Decoder<Object> currentDecoder) throws IOException {
        int code = in.readByte();
        if (code == '+') {
            String result = in.readBytes(in.bytesBefore((byte) '\r')).toString(CharsetUtil.UTF_8);
            in.skipBytes(2);

            handleResult(data, parts, result, false, channel);
        } else if (code == '-') {
            String error = in.readBytes(in.bytesBefore((byte) '\r')).toString(CharsetUtil.UTF_8);
            in.skipBytes(2);

            if (error.startsWith("MOVED")) {
                String[] errorParts = error.split(" ");
                int slot = Integer.valueOf(errorParts[1]);
                String addr = errorParts[2];
                data.getPromise().tryFailure(new RedisMovedException(slot, addr));
            } else if (error.startsWith("ASK")) {
                String[] errorParts = error.split(" ");
                int slot = Integer.valueOf(errorParts[1]);
                String addr = errorParts[2];
                data.getPromise().tryFailure(new RedisAskException(slot, addr));
            } else if (error.startsWith("LOADING")) {
                data.getPromise().tryFailure(new RedisLoadingException(error
                        + ". channel: " + channel + " data: " + data));
            } else if (error.startsWith("OOM")) {
                data.getPromise().tryFailure(new RedisOutOfMemoryException(error.split("OOM ")[1]
                        + ". channel: " + channel + " data: " + data));
            } else if (error.contains("-OOM ")) {
                data.getPromise().tryFailure(new RedisOutOfMemoryException(error.split("-OOM ")[1]
                        + ". channel: " + channel + " data: " + data));
            } else {
                if (data != null) {
                    data.getPromise().tryFailure(new RedisException(error + ". channel: " + channel + " command: " + data));
                } else {
                    log.error("Error: {} channel: {} data: {}", error, channel, data);
                }
            }
        } else if (code == ':') {
            String status = in.readBytes(in.bytesBefore((byte) '\r')).toString(CharsetUtil.UTF_8);
            in.skipBytes(2);
            Object result = Long.valueOf(status);
            handleResult(data, parts, result, false, channel);
        } else if (code == '$') {
            ByteBuf buf = readBytes(in);
            Object result = null;
            if (buf != null) {
                result = decoder(data, parts, currentDecoder).decode(buf, state());
            }
            handleResult(data, parts, result, false, channel);
        } else if (code == '*') {
            long size = readLong(in);
            List<Object> respParts = new ArrayList<Object>();
            boolean top = false; 
//            if (state().trySetSize(size)) {
//                state().setRespParts(respParts);
//                top = true;
//            }

            decodeMulti(in, data, parts, channel, currentDecoder, size, respParts, top);
        } else {
            throw new IllegalStateException("Can't decode replay " + (char)code);
        }
    }

    private void decodeMulti(ByteBuf in, CommandData<Object, Object> data, List<Object> parts,
            Channel channel, Decoder<Object> currentDecoder, long size, List<Object> respParts, boolean top)
                    throws IOException {
        for (int i = respParts.size(); i < size; i++) {
            decode(in, data, respParts, channel, currentDecoder);
//            if (top) {
//                checkpoint();
//            }
        }

        MultiDecoder<Object> decoder = messageDecoder(data, respParts, channel);
        if (decoder == null) {
            return;
        }

        Object result = decoder.decode(respParts, state());


        if (result instanceof Message) {
            // store current message index
            checkpoint();

            handleMultiResult(data, null, channel, result);
            // has next messages?
            if (in.writerIndex() > in.readerIndex()) {
                decode(in, data, null, channel, currentDecoder);
            }
        } else {
            handleMultiResult(data, parts, channel, result);
        }
    }

    private void handleMultiResult(CommandData<Object, Object> data, List<Object> parts,
            Channel channel, Object result) {
        if (data != null) {
            handleResult(data, parts, result, true, channel);
        } else {
            if (result instanceof PubSubStatusMessage) {
                String channelName = ((PubSubStatusMessage) result).getChannel();
                CommandData<Object, Object> d = pubSubChannels.get(channelName);
                if (Arrays.asList("PSUBSCRIBE", "SUBSCRIBE").contains(d.getCommand().getName())) {
                    pubSubChannels.remove(channelName);
                    pubSubMessageDecoders.put(channelName, d.getMessageDecoder());
                }
                if (Arrays.asList("PUNSUBSCRIBE", "UNSUBSCRIBE").contains(d.getCommand().getName())) {
                    pubSubChannels.remove(channelName);
                    pubSubMessageDecoders.remove(channelName);
                }
            }

            RedisPubSubConnection pubSubConnection = RedisPubSubConnection.getFrom(channel);
            if (result instanceof PubSubStatusMessage) {
                pubSubConnection.onMessage((PubSubStatusMessage) result);
            } else if (result instanceof PubSubMessage) {
                pubSubConnection.onMessage((PubSubMessage) result);
            } else {
                pubSubConnection.onMessage((PubSubPatternMessage) result);
            }
        }
    }

    private void handleResult(CommandData<Object, Object> data, List<Object> parts, Object result, boolean multiResult, Channel channel) {
        if (data != null) {
            if (multiResult) {
                result = data.getCommand().getConvertor().convertMulti(result);
            } else {
                result = data.getCommand().getConvertor().convert(result);
            }
        }
        if (parts != null) {
            parts.add(result);
        } else {
            if (!data.getPromise().trySuccess(result) && data.getPromise().cause() instanceof RedisTimeoutException) {
                log.warn("response has been skipped due to timeout! channel: {}, command: {}, result: {}", channel, data, result);
            }
        }
    }

    private MultiDecoder<Object> messageDecoder(CommandData<Object, Object> data, List<Object> parts, Channel channel) {
        if (data == null) {
            if (Arrays.asList("subscribe", "psubscribe", "punsubscribe", "unsubscribe").contains(parts.get(0))) {
                String channelName = (String) parts.get(1);
                CommandData<Object, Object> commandData = pubSubChannels.get(channelName);
                if (commandData == null) {
                    return null;
                }
                return commandData.getCommand().getReplayMultiDecoder();
            } else if (parts.get(0).equals("message")) {
                String channelName = (String) parts.get(1);
                return pubSubMessageDecoders.get(channelName);
            } else if (parts.get(0).equals("pmessage")) {
                String patternName = (String) parts.get(1);
                return pubSubMessageDecoders.get(patternName);
            }
        }

        return data.getCommand().getReplayMultiDecoder();
    }

    private Decoder<Object> decoder(CommandData<Object, Object> data, List<Object> parts, Decoder<Object> currentDecoder) {
        if (data == null) {
            if (parts.size() == 2 && parts.get(0).equals("message")) {
                String channelName = (String) parts.get(1);
                return pubSubMessageDecoders.get(channelName);
            }
            if (parts.size() == 3 && parts.get(0).equals("pmessage")) {
                String patternName = (String) parts.get(1);
                return pubSubMessageDecoders.get(patternName);
            }
            return currentDecoder;
        }

        Decoder<Object> decoder = data.getCommand().getReplayDecoder();
        if (parts != null) {
            MultiDecoder<Object> multiDecoder = data.getCommand().getReplayMultiDecoder();
            if (multiDecoder.isApplicable(parts.size(), state())) {
                decoder = multiDecoder;
            }
        }
        if (decoder == null) {
            if (data.getCommand().getOutParamType() == ValueType.MAP) {
                if (parts.size() % 2 != 0) {
                    decoder = data.getCodec().getMapValueDecoder();
                } else {
                    decoder = data.getCodec().getMapKeyDecoder();
                }
            } else if (data.getCommand().getOutParamType() == ValueType.MAP_KEY) {
                decoder = data.getCodec().getMapKeyDecoder();
            } else if (data.getCommand().getOutParamType() == ValueType.MAP_VALUE) {
                decoder = data.getCodec().getMapValueDecoder();
            } else {
                decoder = data.getCodec().getValueDecoder();
            }
        }
        return decoder;
    }

    public ByteBuf readBytes(ByteBuf is) throws IOException {
        long l = readLong(is);
        if (l > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Java only supports arrays up to " + Integer.MAX_VALUE + " in size");
        }
        int size = (int) l;
        if (size == -1) {
            return null;
        }
        ByteBuf buffer = is.readSlice(size);
        int cr = is.readByte();
        int lf = is.readByte();
        if (cr != CR || lf != LF) {
            throw new IOException("Improper line ending: " + cr + ", " + lf);
        }
        return buffer;
    }

    public static long readLong(ByteBuf is) throws IOException {
        long size = 0;
        int sign = 1;
        int read = is.readByte();
        if (read == '-') {
            read = is.readByte();
            sign = -1;
        }
        do {
            if (read == CR) {
                if (is.readByte() == LF) {
                    break;
                }
            }
            int value = read - ZERO;
            if (value >= 0 && value < 10) {
                size *= 10;
                size += value;
            } else {
                throw new IOException("Invalid character in integer");
            }
            read = is.readByte();
        } while (true);
        return size * sign;
    }

}
