001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.activemq.transport.stomp;
018    
019    import java.io.DataInput;
020    import java.io.DataInputStream;
021    import java.io.DataOutput;
022    import java.io.DataOutputStream;
023    import java.io.IOException;
024    import java.util.HashMap;
025    import java.util.Iterator;
026    import java.util.Map;
027    
028    import org.apache.activemq.util.ByteArrayInputStream;
029    import org.apache.activemq.util.ByteArrayOutputStream;
030    import org.apache.activemq.util.ByteSequence;
031    import org.apache.activemq.wireformat.WireFormat;
032    
033    /**
034     * Implements marshalling and unmarsalling the <a
035     * href="http://stomp.codehaus.org/">Stomp</a> protocol.
036     */
037    public class StompWireFormat implements WireFormat {
038    
039        private static final byte[] NO_DATA = new byte[] {};
040        private static final byte[] END_OF_FRAME = new byte[] {0, '\n'};
041    
042        private static final int MAX_COMMAND_LENGTH = 1024;
043        private static final int MAX_HEADER_LENGTH = 1024 * 10;
044        private static final int MAX_HEADERS = 1000;
045        private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
046    
047        private int version = 1;
048    
049        public ByteSequence marshal(Object command) throws IOException {
050            ByteArrayOutputStream baos = new ByteArrayOutputStream();
051            DataOutputStream dos = new DataOutputStream(baos);
052            marshal(command, dos);
053            dos.close();
054            return baos.toByteSequence();
055        }
056    
057        public Object unmarshal(ByteSequence packet) throws IOException {
058            ByteArrayInputStream stream = new ByteArrayInputStream(packet);
059            DataInputStream dis = new DataInputStream(stream);
060            return unmarshal(dis);
061        }
062    
063        public void marshal(Object command, DataOutput os) throws IOException {
064            StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
065    
066            StringBuffer buffer = new StringBuffer();
067            buffer.append(stomp.getAction());
068            buffer.append(Stomp.NEWLINE);
069    
070            // Output the headers.
071            for (Iterator iter = stomp.getHeaders().entrySet().iterator(); iter.hasNext();) {
072                Map.Entry entry = (Map.Entry)iter.next();
073                buffer.append(entry.getKey());
074                buffer.append(Stomp.Headers.SEPERATOR);
075                buffer.append(entry.getValue());
076                buffer.append(Stomp.NEWLINE);
077            }
078    
079            // Add a newline to seperate the headers from the content.
080            buffer.append(Stomp.NEWLINE);
081    
082            os.write(buffer.toString().getBytes("UTF-8"));
083            os.write(stomp.getContent());
084            os.write(END_OF_FRAME);
085        }
086    
087        public Object unmarshal(DataInput in) throws IOException {
088    
089            try {
090                
091                // parse action
092                String action = parseAction(in);
093    
094                // Parse the headers
095                HashMap<String, String> headers = parseHeaders(in);
096    
097                // Read in the data part.
098                byte[] data = NO_DATA;
099                String contentLength = headers.get(Stomp.Headers.CONTENT_LENGTH);
100                if (contentLength != null) {
101    
102                    // Bless the client, he's telling us how much data to read in.
103                    int length = parseContentLength(contentLength);
104    
105                    data = new byte[length];
106                    in.readFully(data);
107    
108                    if (in.readByte() != 0) {
109                        throw new ProtocolException(Stomp.Headers.CONTENT_LENGTH + " bytes were read and " + "there was no trailing null byte", true);
110                    }
111    
112                } else {
113    
114                    // We don't know how much to read.. data ends when we hit a 0
115                    byte b;
116                    ByteArrayOutputStream baos = null;
117                    while ((b = in.readByte()) != 0) {
118    
119                        if (baos == null) {
120                            baos = new ByteArrayOutputStream();
121                        } else if (baos.size() > MAX_DATA_LENGTH) {
122                            throw new ProtocolException("The maximum data length was exceeded", true);
123                        }
124    
125                        baos.write(b);
126                    }
127    
128                    if (baos != null) {
129                        baos.close();
130                        data = baos.toByteArray();
131                    }
132    
133                }
134    
135                return new StompFrame(action, headers, data);
136    
137            } catch (ProtocolException e) {
138                return new StompFrameError(e);
139            }
140    
141        }
142    
143        private String readLine(DataInput in, int maxLength, String errorMessage) throws IOException {
144            byte b;
145            ByteArrayOutputStream baos = new ByteArrayOutputStream(maxLength);
146            while ((b = in.readByte()) != '\n') {
147                if (baos.size() > maxLength) {
148                    throw new ProtocolException(errorMessage, true);
149                }
150                baos.write(b);
151            }
152            baos.close();
153            ByteSequence sequence = baos.toByteSequence();
154            return new String(sequence.getData(), sequence.getOffset(), sequence.getLength(), "UTF-8");
155        }
156        
157        protected String parseAction(DataInput in) throws IOException {
158            String action = null;
159    
160            // skip white space to next real action line
161            while (true) {
162                action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded");
163                if (action == null) {
164                    throw new IOException("connection was closed");
165                } else {
166                    action = action.trim();
167                    if (action.length() > 0) {
168                        break;
169                    }
170                }
171            }
172            return action;
173        }
174        
175        protected HashMap<String, String> parseHeaders(DataInput in) throws IOException {
176            HashMap<String, String> headers = new HashMap<String, String>(25);
177            while (true) {
178                String line = readLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded");
179                if (line != null && line.trim().length() > 0) {
180    
181                    if (headers.size() > MAX_HEADERS) {
182                        throw new ProtocolException("The maximum number of headers was exceeded", true);
183                    }
184    
185                    try {
186                        int seperatorIndex = line.indexOf(Stomp.Headers.SEPERATOR);
187                        String name = line.substring(0, seperatorIndex).trim();
188                        String value = line.substring(seperatorIndex + 1, line.length()).trim();
189                        headers.put(name, value);
190                    } catch (Exception e) {
191                        throw new ProtocolException("Unable to parser header line [" + line + "]", true);
192                    }
193                } else {
194                    break;
195                }
196            }     
197            return headers;
198        }
199        
200        protected int parseContentLength(String contentLength) throws ProtocolException {
201            int length;
202            try {
203                length = Integer.parseInt(contentLength.trim());
204            } catch (NumberFormatException e) {
205                throw new ProtocolException("Specified content-length is not a valid integer", true);
206            }
207    
208            if (length > MAX_DATA_LENGTH) {
209                throw new ProtocolException("The maximum data length was exceeded", true);
210            }
211            
212            return length;
213        }
214    
215        public int getVersion() {
216            return version;
217        }
218    
219        public void setVersion(int version) {
220            this.version = version;
221        }
222    
223    }